package cn.teammodel.service.impl; import cn.teammodel.ai.JsonLoader; import cn.teammodel.ai.SparkGptClient; import cn.teammodel.ai.SseHelper; import cn.teammodel.ai.cache.HistoryCache; import cn.teammodel.ai.domain.SparkChatRequestParam; import cn.teammodel.ai.listener.SparkGptStreamListener; import cn.teammodel.common.ErrorCode; import cn.teammodel.common.PK; import cn.teammodel.config.exception.ServiceException; import cn.teammodel.model.dto.ai.*; import cn.teammodel.model.dto.ai.comment.*; import cn.teammodel.model.entity.ai.ChatApp; import cn.teammodel.repository.ChatAppRepository; import cn.teammodel.repository.ChatSessionRepository; import cn.teammodel.model.entity.User; import cn.teammodel.model.entity.ai.ChatSession; import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.service.ChatMessageService; import cn.teammodel.utils.RepositoryUtil; import com.alibaba.fastjson2.JSON; import com.alibaba.fastjson2.TypeReference; import com.azure.cosmos.models.CosmosPatchOperations; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import java.time.Instant; import java.util.*; /** * @author winter * @create 2023-12-18 15:20 */ @Service @Slf4j public class ChatMessageServiceImpl implements ChatMessageService { @Resource private SparkGptClient sparkGptClient; @Resource private ChatSessionRepository chatSessionRepository; @Resource private ChatAppRepository chatAppRepository; @Resource private JsonLoader jsonLoader; @Override public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) { // 目前仅使用讯飞星火大模型 String appId = chatCompletionReqDto.getAppId(); SseEmitter sseEmitter; if (StringUtils.isEmpty(appId)) { sseEmitter = completionBySession(chatCompletionReqDto, userId); } else { sseEmitter = completionByApp(chatCompletionReqDto, false); } return sseEmitter; } /** * 评语调用 聊天模型 * * @param chatCommentsDto * @param userId * @param userName * @return */ @Override public SseEmitter chatComments(ChatCommentsDto chatCommentsDto, String userId, String userName) { try { // 目前仅使用讯飞星火大模型 String appId = chatCommentsDto.getAppId(); // 获取模板文本 String text = commentsTemplate(chatCommentsDto); if (!StringUtils.isEmpty(text)) { chatCommentsDto.setText(text); } else { log.info("参数错误"); throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "参数错误"); } SseEmitter sseEmitter; if (StringUtils.isEmpty(appId)) { sseEmitter = commentsBySession(chatCommentsDto, userId, userName); } else { ChatCompletionReqDto chatCompletionReqDto = new ChatCompletionReqDto(); chatCompletionReqDto.setAppId(chatCommentsDto.getAppId()); chatCompletionReqDto.setSessionId(chatCommentsDto.getSessionId()); chatCompletionReqDto.setText(chatCommentsDto.getText()); sseEmitter = completionByApp(chatCompletionReqDto, false); } return sseEmitter; } catch (Exception e) { log.info(Arrays.toString(e.getStackTrace())); log.error("{}-{}", e.getMessage(), Arrays.toString(e.getStackTrace())); throw new ServiceException(ErrorCode.OPERATION_ERROR.getCode(), Arrays.toString(e.getStackTrace())); } } /** * 面具模式(暂时不存储聊天记录) */ private SseEmitter completionByApp(ChatCompletionReqDto chatCompletionReqDto, boolean justApi) { String appId = chatCompletionReqDto.getAppId(); String userPrompt = chatCompletionReqDto.getText(); User user = SecurityUtil.getLoginUser(); String userId = user.getId(); String schoolId = user.getSchoolId(); // 查询 appId 获取 prompt // 通过 prompt 和 userprompt 生成结果 // 直接返回 ChatApp chatApp = RepositoryUtil.findOne(chatAppRepository.findByAppId(appId), "该应用不存在"); // 检验 app 是否可以被该用户使用 if (!schoolId.equals(chatApp.getSchoolId()) && !"public".equals(chatApp.getSchoolId())) { throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "无权使用该应用"); } String appPrompt = chatApp.getPrompt(); SseEmitter sseEmitter = new SseEmitter(-1L); SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); // open 回调 listener.setOnOpen((s) -> { // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) log.info("callback: ws open event emmit"); }); // 对话完成的回调 listener.setOnComplete((s) -> { log.info("callback: ws complete event emmit"); SseHelper.send(sseEmitter, "[DONE]"); // 处理完成后的事件: if (!justApi) { // 保存消息记录, 缓存更改 } }); // 错误的回调 listener.setOnError((s) -> { log.error("callback: ws error, info: " + s); // 返还积分 }); List messageList = new ArrayList<>(); messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt)); messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt)); SparkChatRequestParam requestParam = SparkChatRequestParam .builder() .uid(userId) .chatId(appId) .messageList(messageList) .build(); sparkGptClient.streamChatCompletion(requestParam, listener); return sseEmitter; } /** * 会话模式 */ private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto, String userId) { String userPrompt = chatCompletionReqDto.getText(); String sessionId = chatCompletionReqDto.getSessionId(); ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在"); if (!session.getUserId().equals(userId)) { throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在"); } SseEmitter sseEmitter = new SseEmitter(-1L); SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); // open 回调 listener.setOnOpen((s) -> { // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) log.info("callback: ws open event emmit"); }); // 对话完成的回调 listener.setOnComplete((s) -> { log.info("callback: ws complete event emmit"); SseHelper.send(sseEmitter, "[DONE]"); // 处理完成后的事件: 保存消息记录, 缓存更改 ChatSession.Message message = ChatSession.Message.of(userPrompt, s); HistoryCache.updateContext(sessionId, message); CosmosPatchOperations options = CosmosPatchOperations.create() .replace("/updateTime", Instant.now().toEpochMilli()) .add("/history/-", message); chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options); }); // 错误的回调 listener.setOnError((s) -> { log.error("callback: ws error, info: " + s); // 返还积分 }); List messageList = fetchContext(sessionId, userPrompt); SparkChatRequestParam requestParam = SparkChatRequestParam .builder() .uid(userId) .chatId(sessionId) .messageList(messageList) .build(); sparkGptClient.streamChatCompletion(requestParam, listener); return sseEmitter; } /** * 评语 会话模式 */ private SseEmitter commentsBySession(ChatCommentsDto chatCommentsDto, String userId, String userName) { String userPrompt = chatCommentsDto.getText(); Object data = chatCommentsDto.getData(); //获取会话id 看是否有sessionId 有则直接赋值 没有则赋值userId String sessionId = chatCommentsDto.getSessionId(); ChatSession session = null; List sessions = chatSessionRepository.findCommentsById(sessionId); if (sessions.size() == 0) { // 初始化欢迎语 ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆," + "你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); List history = Collections.singletonList(message); session = new ChatSession(); session.setId(sessionId); session.setCode(PK.CHAT_SESSION); session.setTitle("评语"); session.setUserId(userId); session.setCreateTime(Instant.now().toEpochMilli()); session.setUpdateTime(Instant.now().toEpochMilli()); session.setHistory(history); chatSessionRepository.save(session); } else { session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在"); } /* ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在"); if (!session.getUserId().equals(userId)) { throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在"); }*/ SseEmitter sseEmitter = new SseEmitter(-1L); SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); // open 回调 listener.setOnOpen((s) -> { // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) log.info("callback: ws open event emmit"); }); // 对话完成的回调 listener.setOnComplete((s) -> { log.info("callback: ws complete event emmit"); SseHelper.send(sseEmitter, "[DONE]"); // 处理完成后的事件: 保存消息记录, 缓存更改 ChatSession.Message message = ChatSession.Message.of(userPrompt, s); HistoryCache.updateContext(sessionId, message); CosmosPatchOperations options = CosmosPatchOperations.create() .replace("/updateTime", Instant.now().toEpochMilli()) .add("/history/-", message); chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options); }); // 错误的回调 listener.setOnError((s) -> { log.error("callback: ws error, info: " + s); // 返还积分 }); List messageList = fetchContext(userId, userPrompt); SparkChatRequestParam requestParam = SparkChatRequestParam .builder() .uid(userId) .chatId(userId) .messageList(messageList) .build(); sparkGptClient.streamChatCompletion(requestParam, listener); return sseEmitter; } List fetchContext(String userId, String prompt) { List context = HistoryCache.getContext(userId); List paramMessages = new ArrayList<>(); // 暂未缓存,从数据库拉取 if (ObjectUtils.isEmpty(context)) { context = chatSessionRepository.findLatestMessage(userId); if (ObjectUtils.isNotEmpty(context)) { HistoryCache.putContext(userId, context); } } // convert DB Message to Spark Message context.forEach(item -> { paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText())); paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText())); }); paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt)); return paramMessages; } /** * 评语调用 聊天模型 * 待优化 * * @return */ private String commentsTemplate(ChatCommentsDto chatCommentsDto) { try { StringBuilder builder = new StringBuilder(); String strData = JSON.toJSONString(chatCommentsDto.getData()); //检查是否角色 int phase = Math.max(chatCommentsDto.getPhase(), 0); List chatModels = new ArrayList<>(); ChatModelDto chatModel = null; // 获取模型数据 chatModels = jsonLoader.myJsonDataBean(); //循环查找对应的模型数据 for (ChatModelDto chatModelTemp : chatModels) { //判断评语类型 if (chatCommentsDto.getType().equals(chatModelTemp.getType())) { chatModel = chatModelTemp; break; } } String chatName = ""; switch(phase){ case 1: if (chatCommentsDto.getName().contains("班") || chatCommentsDto.getName().contains("班级")) { chatName = chatCommentsDto.getName(); } else { chatName = chatCommentsDto.getName()+"班级"; } break; case 2: if (chatCommentsDto.getName().contains("年级")) { chatName = chatCommentsDto.getName(); } else { chatName = chatCommentsDto.getName()+"年级"; } break; default: chatName = chatCommentsDto.getName()+"同学"; break; } if (chatModel != null) { //角色条件 builder.append(String.format(chatModel.getRole().get(0), chatCommentsDto.getPeriod(), chatCommentsDto.getSubject(), chatName)); } ChatModelDto finalChatModel = chatModel; //模版 switch (chatCommentsDto.getType()) { //智育 总体评语模版 case "wisdom": { WisdomCommentsDto wisdomComments; //转换问题 try { //转换方式 wisdomComments = JSON.parseObject(strData, WisdomCommentsDto.class); } catch (Exception e) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "类型转换失败"); } if (wisdomComments.getName() == null) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } String chat = phase > 0 ? finalChatModel.getChat().get(1) : finalChatModel.getChat().get(0); String cycleChat = finalChatModel.getCycleChats().get(0); // 使用String.format方法生成最终的字符串 builder.append(String.format( chat, chatCommentsDto.getName(), wisdomComments.getYear(), wisdomComments.getSemester(), wisdomComments.getJoinAll(), wisdomComments.getLessonMiddle(), wisdomComments.getOnLine(), wisdomComments.getMarking(), wisdomComments.getLevel(), wisdomComments.getProportion(), wisdomComments.getName() )); wisdomComments.getDims().forEach(item -> { if(item.data.length >= 4) { builder.append(String.format(cycleChat, item.name, item.data[0], item.data[1], item.data[2], item.data[3], item.data[4] )); } }); //builder.append(String.format(finalChatModel.getEnd().get(0),chatName, wisdomComments.getDims().size())); break; } //智育 表现模版 case "wisdomExam": { List examComments = new ArrayList<>(); try { examComments = JSON.parseObject(strData, new TypeReference>() { }); } catch (Exception e) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "类型转换失败"); } if (examComments.size() <= 1) { if (examComments.isEmpty()) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } if (examComments.get(0).name == null) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } } int count = 1; if(finalChatModel.getChat() != null && !chatCommentsDto.getName().isEmpty()){ //昵称 builder.append(String.format(finalChatModel.getChat().get(0), chatCommentsDto.getName())); } switch(phase){ case 1:{ for (WisdomExamCommentsDto examComment : examComments) { builder.append(String.format(finalChatModel.getCycleChats().get(1), count, examComment.getName(),examComment.getTime(), examComment.getClassRate(),examComment.getGradeRate())); count++; } break; } case 2:{ for (WisdomExamCommentsDto examComment : examComments) { builder.append(String.format(finalChatModel.getCycleChats().get(2), count, examComment.getName(),examComment.getTime(), examComment.getGradeRate())); count++; } break; } default:{ for (WisdomExamCommentsDto examComment : examComments) { builder.append(String.format(finalChatModel.getCycleChats().get(0), count, examComment.getName(),examComment.getTime(), examComment.getScore(), examComment.getScoreRate(),examComment.getRanking())); count++; } break; } } break; } // 智育 学科评语模版 case "wisdomSubject": { List subjectComments; try { subjectComments = JSON.parseObject(strData, new TypeReference>() {}); } catch (Exception e) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } if (subjectComments.isEmpty()) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } if(finalChatModel.getChat() != null && !chatCommentsDto.getName().isEmpty()){ builder.append(String.format(finalChatModel.getChat().get(1), chatCommentsDto.getName())); } else { if (finalChatModel.getChat() != null) { builder.append(finalChatModel.getChat().get(0)); } } String name = ""; //拼接学科数组 for (WisdomSubjectComments comments : subjectComments) { builder.append(String.format(finalChatModel.getCycleChats().get(0), comments.subjectName)); for (WisdomSubjectComments.StuInfo stuInfo : comments.getRankings()) { builder.append(String.format(finalChatModel.getCycleChats().get(1), stuInfo.ranking,stuInfo.name, stuInfo.scoreRate*100)); } if (phase == 0) { name = comments.getClaasRanking().name; builder.append(String.format(finalChatModel.getCycleChats().get(2), name, comments.getClaasRanking().ranking, comments.getClaasRanking().scoreRate * 100, comments.getClaasRanking().average * 100)); } } //builder.append(String.format(finalChatModel.getEnd().get(0), subjectComments.size(),chatName)); break; } //艺术 考核指标纬度评语 case "artLatitude":{ List artLatitudes; try { artLatitudes = JSON.parseObject(strData, new TypeReference>() { }); } catch (Exception e) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } if (artLatitudes.size() <= 1) { if (artLatitudes.isEmpty()) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } } if(finalChatModel.getChat() != null && !chatCommentsDto.getName().isEmpty()) { builder.append(String.format(finalChatModel.getChat().get(1), chatCommentsDto.getName())); }else { if (finalChatModel.getChat() != null) { builder.append(finalChatModel.getChat().get(0)); } } for (ArtLatitudeDto artLatitude : artLatitudes){ builder.append(String.format(finalChatModel.getCycleChats().get(0), artLatitude.getQuotaN1(),artLatitude.getQuotaP1(), artLatitude.getQuotaN2(),artLatitude.getQuotaP2(), artLatitude.getQuotaN3(),artLatitude.getQuotaP3(), artLatitude.getPercent(),artLatitude.getLevel())); } //builder.append(String.format(finalChatModel.getEnd().get(0),chatName)); break; } //艺术 学科评语模版 case "artSubject":{ List artSubjects; try { artSubjects = JSON.parseObject(strData, new TypeReference>() {}); } catch (Exception e) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } if (artSubjects.size() <= 1) { if (artSubjects.isEmpty()) { throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常"); } } builder.append(String.format(finalChatModel.getChat().get(0), chatCommentsDto.getSubject())); for (ArtSubjectDto artSubject : artSubjects){ builder.append(String.format(finalChatModel.getCycleChats().get(0), artSubject.getDimension(),artSubject.getBlock(), artSubject.getPoint(),artSubject.getScore(), artSubject.getTotalScore(),artSubject.getPercent())); } //builder.append(String.format(finalChatModel.getEnd().get(0),chatName)); break; } //体育 case "sport":{ builder.append("请按照以下格式回复:\n"); builder.append("1. 运动作品:\n"); builder.append("2. 运动作品说明:\n"); builder.append("3. 运动作品示例:\n"); builder.append("4. 运动作品示例说明:\n"); break; } //德育 case "moral":{ builder.append(finalChatModel.getChat()); break; } default: throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "评语类型异常"); } String limitChat = "。限制条件如下:1.%s字左右;2.直接返回评价的内容;3.评价开头不要提示语;4.评价开头不允许出现特殊字符;"; int size = chatCommentsDto.getSize() > 0 ? chatCommentsDto.getSize() : 200; builder.append(String.format(limitChat, size)); //用户自定义限制条件 if (chatCommentsDto.getLimitTexts() != null && !chatCommentsDto.getLimitTexts().isEmpty()) { int serialNumber = 5; //用户自定义限制条件数量标识 List limitTexts = chatCommentsDto.getLimitTexts(); //获取自定义限制条件内容 int length = limitTexts.size(); //数组大小 for (int i = 0; i < length; i++) { String str = limitTexts.get(i); builder.append(serialNumber).append(".").append(str); if (i < length - 1) { builder.append(";"); }else { builder.append("。"); } serialNumber += 1; } } return builder.toString(); } catch (Exception e) { log.info(Arrays.toString(e.getStackTrace())); log.error("{}-{}", e.getMessage(), Arrays.toString(e.getStackTrace())); throw new ServiceException(ErrorCode.OPERATION_ERROR.getCode(), Arrays.toString(e.getStackTrace())); } } }