parent
0bae71207b
commit
fee64cb1ac
@ -0,0 +1,115 @@
|
||||
package cn.teammodel.ai;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import cn.teammodel.ai.domain.SparkChatRequestParam;
|
||||
import cn.teammodel.ai.listener.SparkGptStreamListener;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.HttpUrl;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.crypto.Mac;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* spark Gpt client
|
||||
* @author winter
|
||||
* @create 2023-12-15 14:29
|
||||
*/
|
||||
@Component
|
||||
@Data
|
||||
@Slf4j
|
||||
public class SparkGptClient implements InitializingBean {
|
||||
@Resource
|
||||
private SparkGptProperties sparkGptProperties;
|
||||
private OkHttpClient okHttpClient;
|
||||
private String authUrl;
|
||||
|
||||
/**
|
||||
* 静态构造对象方法
|
||||
*/
|
||||
public void init() {
|
||||
String authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret());
|
||||
this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://");
|
||||
log.info("鉴权 url: {}", this.authUrl);
|
||||
|
||||
this.okHttpClient = new OkHttpClient()
|
||||
.newBuilder()
|
||||
.connectTimeout(90, TimeUnit.SECONDS)
|
||||
.readTimeout(90, TimeUnit.SECONDS)
|
||||
.writeTimeout(90, TimeUnit.SECONDS)
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式的生成结果,以 sse 的方式向客户端推送
|
||||
*/
|
||||
public void streamChatCompletion(SparkChatRequestParam param, SparkGptStreamListener listener) {
|
||||
try {
|
||||
param.setAppId(sparkGptProperties.getAppId());
|
||||
Request request = new Request.Builder().url(authUrl).build();
|
||||
// 设置请求参数
|
||||
listener.setRequestJson(param.toJsonParams());
|
||||
log.info("请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty());
|
||||
okHttpClient.newWebSocket(request, listener);
|
||||
} catch (Exception e) {
|
||||
log.error("Spark AI 请求异常: {}", e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成鉴权URL
|
||||
*/
|
||||
public static String genAuthUrl(String endpoint, String apiKey, String apiSecret) {
|
||||
URL url = null;
|
||||
String date = null;
|
||||
String preStr = null;
|
||||
Mac mac = null;
|
||||
try {
|
||||
url = new URL(endpoint);
|
||||
// 时间
|
||||
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
|
||||
format.setTimeZone(TimeZone.getTimeZone("GMT"));
|
||||
date = format.format(new Date());
|
||||
// 拼接
|
||||
preStr = "host: " + url.getHost() + "\n" +
|
||||
"date: " + date + "\n" +
|
||||
"GET " + url.getPath() + " HTTP/1.1";
|
||||
// SHA256加密
|
||||
mac = Mac.getInstance("hmacsha256");
|
||||
SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
|
||||
mac.init(spec);
|
||||
} catch (Exception e) {
|
||||
log.error("生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
|
||||
// Base64加密
|
||||
String sha = Base64.getEncoder().encodeToString(hexDigits);
|
||||
// 拼接
|
||||
String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
|
||||
// 拼接地址
|
||||
HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().
|
||||
addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).
|
||||
addQueryParameter("date", date).//
|
||||
addQueryParameter("host", url.getHost()).//
|
||||
build();
|
||||
return httpUrl.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() throws Exception {
|
||||
init();
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
package cn.teammodel.ai;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
/**
|
||||
* @author winter
|
||||
* @create 2023-12-15 14:29
|
||||
*/
|
||||
|
||||
@Data
|
||||
@Configuration
|
||||
@ConfigurationProperties(prefix = "spark.gpt")
|
||||
public class SparkGptProperties {
|
||||
private String endpoint;
|
||||
private String appId;
|
||||
private String apiKey;
|
||||
private String apiSecret;
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package cn.teammodel.ai;
|
||||
|
||||
import lombok.experimental.UtilityClass;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
@UtilityClass
|
||||
@Slf4j
|
||||
public class SseHelper {
|
||||
/**
|
||||
* 断开 server 与 client 的 sse 连接
|
||||
*/
|
||||
public void complete(SseEmitter sseEmitter) {
|
||||
try {
|
||||
sseEmitter.complete();
|
||||
} catch (Exception e) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 向 client 发送消息
|
||||
*/
|
||||
public void send(SseEmitter sseEmitter, Object data) {
|
||||
try {
|
||||
sseEmitter.send(data);
|
||||
} catch (Exception e) {
|
||||
log.error("sseEmitter send error", e);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
package cn.teammodel.ai.listener;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import cn.teammodel.ai.SseHelper;
|
||||
import cn.teammodel.ai.domain.SparkChatResponse;
|
||||
import cn.teammodel.common.ErrorCode;
|
||||
import cn.teammodel.config.exception.ServiceException;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.WebSocket;
|
||||
import okhttp3.WebSocketListener;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import javax.validation.constraints.NotNull;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* okhttp 调用 ws 接口时的 listner
|
||||
* @author winter
|
||||
* @create 2023-12-15 16:17
|
||||
*/
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Data
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class SparkGptStreamListener extends WebSocketListener {
|
||||
private final SseEmitter sseEmitter;
|
||||
private String requestJson;
|
||||
/**
|
||||
* 模型响应的完整回复
|
||||
*/
|
||||
private String answer = "";
|
||||
|
||||
@Override
|
||||
public void onOpen(WebSocket webSocket, @NotNull Response response) {
|
||||
// ws 建立连接后发送请求的数据, 在 onMessage 中接受 server 的响应数据
|
||||
webSocket.send(requestJson);
|
||||
// 执行成功回调
|
||||
try {
|
||||
onOpen.accept(webSocket);
|
||||
} catch (Exception e) {
|
||||
// todo: 这儿不应该直接调她
|
||||
this.onFailure(webSocket, e, response);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(WebSocket webSocket, String text) {
|
||||
// 向 sse 推送消息(一次 onMessage 事件推送一次)
|
||||
SparkChatResponse sparkChatResponse = JSONUtil.toBean(text, SparkChatResponse.class);
|
||||
// 请求是否异常
|
||||
if (sparkChatResponse.getHeader().getCode() != 0) {
|
||||
this.onFailure(
|
||||
webSocket,
|
||||
new ServiceException(ErrorCode.SYSTEM_ERROR.getCode(),
|
||||
sparkChatResponse.getHeader().getMessage()),
|
||||
null
|
||||
);
|
||||
}
|
||||
// 推送回答 segment
|
||||
String msgSegment = sparkChatResponse.getPayload().getChoices().getText().get(0).getContent();
|
||||
// 处理消息格式(空格和换行符)
|
||||
msgSegment = StrUtil.replace(msgSegment, " ", " ").replaceAll("\n", "\n");
|
||||
answer += msgSegment;
|
||||
SseHelper.send(sseEmitter, msgSegment);
|
||||
|
||||
// 处理模型的最终响应
|
||||
if (sparkChatResponse.getHeader().getStatus() == 2) {
|
||||
// 其实 spark 会主动断开连接
|
||||
webSocket.close(1000, "done");
|
||||
onComplete.accept(answer);
|
||||
SseHelper.complete(sseEmitter);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(WebSocket webSocket, Throwable t, @Nullable Response response) {
|
||||
webSocket.close(1000, t.getMessage());
|
||||
this.onError.accept(t);
|
||||
// 失败时结束 sse 连接
|
||||
SseHelper.send(sseEmitter,t.getMessage() + "[DONE]");
|
||||
SseHelper.complete(sseEmitter);
|
||||
}
|
||||
|
||||
|
||||
// 这几个 function 可以在 listener 被调用时设置, 实现类似事件的回调
|
||||
protected Consumer<WebSocket> onOpen = (s) -> {};
|
||||
protected Consumer<Throwable> onError = (s) -> {};
|
||||
protected Consumer<String> onComplete = (s) -> {};
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
package cn.teammodel.common;
|
||||
|
||||
/**
|
||||
* @author winter
|
||||
* @create 2023-12-14 12:19
|
||||
*/
|
||||
public interface FiveEducations {
|
||||
String MORAL = "美德";
|
||||
String INTELLECTUAL = "美智";
|
||||
String PHYSICAL = "体育";
|
||||
String AESTHETIC = "美艺";
|
||||
String LABOUR = "劳动";
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
package cn.teammodel.controller.admin.controller;
|
||||
|
||||
import cn.teammodel.common.R;
|
||||
import cn.teammodel.controller.admin.service.AdminAppraiseService;
|
||||
import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto;
|
||||
import cn.teammodel.model.entity.appraise.AchievementRule;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author winter
|
||||
* @create 2023-12-13 16:07
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("admin/appraise")
|
||||
public class AdminAppraiseController {
|
||||
@Resource
|
||||
private AdminAppraiseService adminAppraiseService;
|
||||
|
||||
@PostMapping("updateAchieveRule")
|
||||
@ApiOperation("更新的 rule 节点将会直接覆盖老节点")
|
||||
public R<List<AchievementRule>> updateAchieveRule(@RequestBody UpdateAchievementRuleDto ruleDto) {
|
||||
List<AchievementRule> res = adminAppraiseService.updateAchieveRule(ruleDto);
|
||||
return R.success(res);
|
||||
}
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package cn.teammodel.controller.frontend;
|
||||
|
||||
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
|
||||
import cn.teammodel.service.ChatMessageService;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.validation.Valid;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/public/ai")
|
||||
public class AiController {
|
||||
@Resource
|
||||
private ChatMessageService chatMessageService;
|
||||
|
||||
@PostMapping("chat/completion")
|
||||
@ApiOperation("与 spark 的流式对话")
|
||||
public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) {
|
||||
return chatMessageService.chatCompletion(chatCompletionReqDto);
|
||||
}
|
||||
@GetMapping("test/completion")
|
||||
@ApiOperation("与 spark 的流式对话")
|
||||
public SseEmitter testChatCompletion() {
|
||||
return chatMessageService.chatCompletion(null);
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package cn.teammodel.controller;
|
||||
package cn.teammodel.controller.frontend;
|
||||
|
||||
import cn.teammodel.common.R;
|
||||
import cn.teammodel.dao.AppraiseRepository;
|
@ -0,0 +1,45 @@
|
||||
package cn.teammodel.model.dto.admin;
|
||||
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
/**
|
||||
* @author winter
|
||||
* @create 2023-12-13 17:34
|
||||
*/
|
||||
@Data
|
||||
public class UpdateAchievementRuleDto {
|
||||
@NotNull
|
||||
@ApiModelProperty("学段 id")
|
||||
private String periodId;
|
||||
@ApiModelProperty("更新的 rule 节点: 将会直接覆盖老节点")
|
||||
private UpdateRule updateRule;
|
||||
|
||||
@Data
|
||||
public static class UpdateRule {
|
||||
@NotNull
|
||||
private String id;
|
||||
/**
|
||||
* 等级名称
|
||||
*/
|
||||
@NotNull
|
||||
private String name;
|
||||
/**
|
||||
* 等级 logo
|
||||
*/
|
||||
@NotNull
|
||||
private String logo;
|
||||
/**
|
||||
* 每次所需表扬数
|
||||
*/
|
||||
@NotNull
|
||||
private Integer levelCount;
|
||||
/**
|
||||
* 晋级所需下一等级所需当前等级次数
|
||||
*/
|
||||
@NotNull
|
||||
private Integer promotionLevel;
|
||||
}
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
package cn.teammodel.model.dto.ai;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import javax.validation.constraints.NotBlank;
|
||||
|
||||
@Data
|
||||
public class ChatCompletionReqDto {
|
||||
private Long sessionId;
|
||||
/**
|
||||
* 预设的会话面具
|
||||
*/
|
||||
private Long appId;
|
||||
@NotBlank(message = "请输入消息内容")
|
||||
private String text;
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
package cn.teammodel.model.entity.appraise;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 成就的晋级规则
|
||||
* @author winter
|
||||
* @create 2023-12-13 15:23
|
||||
*/
|
||||
@Data
|
||||
public class AchievementRule {
|
||||
private String id;
|
||||
/**
|
||||
* 等级名称
|
||||
*/
|
||||
private String name;
|
||||
/**
|
||||
* 等级 logo
|
||||
*/
|
||||
private String logo;
|
||||
/**
|
||||
* 等级顺序
|
||||
*/
|
||||
private Integer level;
|
||||
/**
|
||||
* 每次所需表扬数
|
||||
*/
|
||||
private Integer levelCount;
|
||||
/**
|
||||
* 晋级所需下一等级所需当前等级次数
|
||||
*/
|
||||
private Integer promotionLevel;
|
||||
/**
|
||||
* 晋升到下一等级所需总表扬树
|
||||
*/
|
||||
private Integer promotionCount;
|
||||
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
package cn.teammodel.model.entity.school;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 学校的配置项
|
||||
* @author winter
|
||||
* @create 2023-12-13 11:44
|
||||
*/
|
||||
@Data
|
||||
public class SchoolConfig {
|
||||
/**
|
||||
* 学校积分名字
|
||||
*/
|
||||
private String scoreName = "醍摩豆";
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
package cn.teammodel.model.vo.appraise;
|
||||
|
||||
import cn.teammodel.model.entity.appraise.AchievementRule;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 学生个人评价 vo
|
||||
* @author winter
|
||||
* @create 2023-12-04 15:26
|
||||
*/
|
||||
@Data
|
||||
public class StudentReportVo {
|
||||
private Integer praiseCount;
|
||||
private Integer score;
|
||||
private Map<String, Integer> praiseDistribution;
|
||||
private Map<String, Integer> criticalDistribution;
|
||||
private AchievementRule curAchievement;
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package cn.teammodel.service;
|
||||
|
||||
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
/**
|
||||
* @author winter
|
||||
* @create 2023-12-18 15:20
|
||||
*/
|
||||
public interface ChatMessageService {
|
||||
/**
|
||||
* AI 聊天
|
||||
*/
|
||||
SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto);
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
package cn.teammodel.service.impl;
|
||||
|
||||
import cn.teammodel.ai.SparkGptClient;
|
||||
import cn.teammodel.ai.SseHelper;
|
||||
import cn.teammodel.ai.domain.SparkChatRequestParam;
|
||||
import cn.teammodel.ai.listener.SparkGptStreamListener;
|
||||
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
|
||||
import cn.teammodel.model.entity.User;
|
||||
import cn.teammodel.security.utils.SecurityUtil;
|
||||
import cn.teammodel.service.ChatMessageService;
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author winter
|
||||
* @create 2023-12-18 15:20
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ChatMessageServiceImpl implements ChatMessageService {
|
||||
@Resource
|
||||
private SparkGptClient sparkGptClient;
|
||||
|
||||
@Override
|
||||
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) {
|
||||
// 目前仅使用讯飞星火大模型
|
||||
User user = SecurityUtil.getLoginUser();
|
||||
String userId = user.getId();
|
||||
String text = chatCompletionReqDto.getText();
|
||||
// String userId = "123";
|
||||
// String text = "hello, how should I call you?";
|
||||
SseEmitter sseEmitter = new SseEmitter(-1L);
|
||||
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
|
||||
|
||||
// open 回调
|
||||
listener.setOnOpen((s) -> {
|
||||
// 敏感词检查,计费
|
||||
log.info("callback: ws open event emmit");
|
||||
});
|
||||
// 对话完成的回调
|
||||
listener.setOnComplete((s) -> {
|
||||
log.info("callback: ws complete event emmit");
|
||||
SseHelper.send(sseEmitter, "[DONE]");
|
||||
// 处理完成后的事件: 保存消息记录
|
||||
});
|
||||
// 错误的回调
|
||||
listener.setOnError((s) -> {
|
||||
log.error("callback: ws error" );
|
||||
// 返还积分
|
||||
});
|
||||
// todo: 拉取对话上下文
|
||||
List<SparkChatRequestParam.Message> messageList = Lists.newArrayList();
|
||||
messageList.add(SparkChatRequestParam.Message.ofUser(text));
|
||||
// todo: sessionId
|
||||
SparkChatRequestParam requestParam = SparkChatRequestParam
|
||||
.builder()
|
||||
.uid(userId)
|
||||
.chatId("123")
|
||||
.messageList(messageList)
|
||||
.build();
|
||||
sparkGptClient.streamChatCompletion(requestParam, listener);
|
||||
return sseEmitter;
|
||||
}
|
||||
}
|
Loading…
Reference in new issue