You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
IESExtension/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java

499 lines
22 KiB

package cn.teammodel.service.impl;
import cn.teammodel.ai.JsonLoader;
import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import cn.teammodel.common.ErrorCode;
import cn.teammodel.common.PK;
import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.model.dto.ai.*;
import cn.teammodel.model.dto.ai.comment.ChatCommentsDto;
import cn.teammodel.model.dto.ai.comment.WisdomCommentsDto;
import cn.teammodel.model.dto.ai.comment.WisdomExamCommentsDto;
import cn.teammodel.model.dto.ai.comment.WisdomSubjectComments;
import cn.teammodel.model.entity.ai.ChatApp;
import cn.teammodel.repository.ChatAppRepository;
import cn.teammodel.repository.ChatSessionRepository;
import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatMessageService;
import cn.teammodel.utils.FileUtil;
import cn.teammodel.utils.RepositoryUtil;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.TypeReference;
import com.azure.cosmos.models.CosmosPatchOperations;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.*;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.time.Instant;
import java.util.*;
/**
* @author winter
* @create 2023-12-18 15:20
*/
@Service
@Slf4j
public class ChatMessageServiceImpl implements ChatMessageService {
@Resource
private SparkGptClient sparkGptClient;
@Resource
private ChatSessionRepository chatSessionRepository;
@Resource
private ChatAppRepository chatAppRepository;
@Resource
private JsonLoader jsonLoader;
@Override
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) {
// 目前仅使用讯飞星火大模型
String appId = chatCompletionReqDto.getAppId();
SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) {
sseEmitter = completionBySession(chatCompletionReqDto, userId);
} else {
sseEmitter = completionByApp(chatCompletionReqDto, false);
}
return sseEmitter;
}
/**
*
*
* @param chatCommentsDto
* @param userId
* @param userName
* @return
*/
@Override
public SseEmitter chatComments(ChatCommentsDto chatCommentsDto, String userId, String userName) {
try {
// 目前仅使用讯飞星火大模型
String appId = chatCommentsDto.getAppId();
String text = commentsTemplate(chatCommentsDto);
if (!StringUtils.isEmpty(text)) {
chatCommentsDto.setText(text);
} else {
log.info("参数错误");
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "参数错误");
}
SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) {
sseEmitter = commentsBySession(chatCommentsDto, userId, userName);
} else {
ChatCompletionReqDto chatCompletionReqDto = new ChatCompletionReqDto();
chatCompletionReqDto.setAppId(chatCommentsDto.getAppId());
chatCompletionReqDto.setSessionId(chatCommentsDto.getSessionId());
chatCompletionReqDto.setText(chatCommentsDto.getText());
sseEmitter = completionByApp(chatCompletionReqDto, false);
}
return sseEmitter;
} catch (Exception e) {
log.info(Arrays.toString(e.getStackTrace()));
log.error("{}-{}", e.getMessage(), Arrays.toString(e.getStackTrace()));
throw new ServiceException(ErrorCode.OPERATION_ERROR.getCode(), Arrays.toString(e.getStackTrace()));
}
}
/**
* ()
*/
private SseEmitter completionByApp(ChatCompletionReqDto chatCompletionReqDto, boolean justApi) {
String appId = chatCompletionReqDto.getAppId();
String userPrompt = chatCompletionReqDto.getText();
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
String schoolId = user.getSchoolId();
// 查询 appId 获取 prompt
// 通过 prompt 和 userprompt 生成结果
// 直接返回
ChatApp chatApp = RepositoryUtil.findOne(chatAppRepository.findByAppId(appId), "该应用不存在");
// 检验 app 是否可以被该用户使用
if (!schoolId.equals(chatApp.getSchoolId()) && !"public".equals(chatApp.getSchoolId())) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "无权使用该应用");
}
String appPrompt = chatApp.getPrompt();
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件:
if (!justApi) {
// 保存消息记录, 缓存更改
}
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = new ArrayList<>();
messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt));
messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt));
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(appId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
/**
*
*/
private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto, String userId) {
String userPrompt = chatCompletionReqDto.getText();
String sessionId = chatCompletionReqDto.getSessionId();
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
}
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(sessionId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
/**
*
*/
private SseEmitter commentsBySession(ChatCommentsDto chatCommentsDto, String userId, String userName) {
String userPrompt = chatCommentsDto.getText();
Object data = chatCommentsDto.getData();
//获取会话id 看是否有sessionId 有则直接赋值 没有则赋值userId
String sessionId = chatCommentsDto.getSessionId();
ChatSession session = null;
List<ChatSession> sessions = chatSessionRepository.findCommentsById(sessionId);
if (sessions.size() == 0) {
// 初始化欢迎语
ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆," +
"你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
List<ChatSession.Message> history = Collections.singletonList(message);
session = new ChatSession();
session.setId(sessionId);
session.setCode(PK.CHAT_SESSION);
session.setTitle("评语");
session.setUserId(userId);
session.setCreateTime(Instant.now().toEpochMilli());
session.setUpdateTime(Instant.now().toEpochMilli());
session.setHistory(history);
chatSessionRepository.save(session);
} else {
session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
}
/*
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
}*/
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(userId, userPrompt);
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(userId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
List<SparkChatRequestParam.Message> fetchContext(String userId, String prompt) {
List<ChatSession.Message> context = HistoryCache.getContext(userId);
List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>();
// 暂未缓存,从数据库拉取
if (ObjectUtils.isEmpty(context)) {
context = chatSessionRepository.findLatestMessage(userId);
if (ObjectUtils.isNotEmpty(context)) {
HistoryCache.putContext(userId, context);
}
}
// convert DB Message to Spark Message
context.forEach(item -> {
paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText()));
paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText()));
});
paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt));
return paramMessages;
}
/**
*
*
*
* @return
*/
private String commentsTemplate(ChatCommentsDto chatCommentsDto) {
StringBuilder builder = new StringBuilder();
String strData = JSON.toJSONString(chatCommentsDto.getData());
List<ChatModelDto> chatModels = new ArrayList<>();
ChatModelDto chatModel = null;
// 获取模型数据
chatModels = jsonLoader.myJsonDataBean();
/*
//验证获取模型数据 异常问题
try {
String fileText = FileUtil.getFileText("Json/ChatModel.json");
8 months ago
String jsonData = JSON.toJSONString(fileText);
//获取聊天字段中的数据
Object obj = JSON.parseObject(jsonData).get("chatModel");
String jsonData01 = JSON.toJSONString(obj);
//转换方式
chatModels = JSON.parseObject(jsonData01, new TypeReference<List<ChatModelDto>>() {});
8 months ago
log.info("获取地址fileText"+fileText+"----文件内容Data:"+ jsonData +"----获取模型集合Object" + obj +"----获取模型集合String"+ jsonData01 +"----获取模型集合机构:"+chatModels);
} catch (Exception e) {
throw new ServiceException(ErrorCode.OPERATION_ERROR.getCode(), "读取文件" + Arrays.toString(e.getStackTrace()) + e.getMessage());
}*/
//循环查找对应的模型数据
for (ChatModelDto chatModelTemp : chatModels) {
//判断评语类型
if (chatCommentsDto.getType().equals(chatModelTemp.getType())) {
chatModel = chatModelTemp;
break;
}
}
if (chatModel != null) {
//角色条件
8 months ago
builder.append(String.format(chatModel.getRole(), chatCommentsDto.getPeriod(), chatCommentsDto.getSubject()));
}
ChatModelDto finalChatModel = chatModel;
//模版
switch (chatCommentsDto.getType()) {
//智育 总体评语模版
case "wisdom": {
WisdomCommentsDto wisdomComments;
//转换问题
try {
//转换方式
wisdomComments = JSON.parseObject(strData, WisdomCommentsDto.class);
} catch (Exception e) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "类型转换失败");
}
if (wisdomComments.getName() == null) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
// 使用String.format方法生成最终的字符串
builder.append(String.format(
finalChatModel.getChat(),
wisdomComments.getName(),
wisdomComments.getYear(),
wisdomComments.getSemester(),
wisdomComments.getJoinAll(),
wisdomComments.getLessonMiddle(),
wisdomComments.getOnLine(),
wisdomComments.getMarking(),
wisdomComments.getLevel(),
wisdomComments.getProportion(),
wisdomComments.getName()
));
/*
//有平均分
builder.append(String.format(
finalChatModel.getChat(),
wisdomComments.getName(),
wisdomComments.getYear(),
wisdomComments.getSemester(),
wisdomComments.getJoinAll(),
wisdomComments.getLessonMiddle(),
wisdomComments.getOnLine(),
wisdomComments.getMarking(),
wisdomComments.getAverage(),
wisdomComments.getLevel(),
wisdomComments.getProportion(),
wisdomComments.getName()
));*/
wisdomComments.getDims().forEach(item -> {
builder.append(String.format(finalChatModel.getCycleChats().get(0),
item.name,
item.data[0],
item.data[1],
item.data[2],
item.data[3],
item.data[4]
));
});
builder.append(String.format(finalChatModel.getEnd(), wisdomComments.getName()));
break;
}
//智育 表现模版
case "wisdomExam": {
List<WisdomExamCommentsDto> examComments = new ArrayList<>();
try {
examComments = JSON.parseObject(strData, new TypeReference<List<WisdomExamCommentsDto>>() {
});
} catch (Exception e) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "类型转换失败");
}
if (examComments.size() <= 1) {
if (examComments.isEmpty()) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
if (examComments.get(0).name == null) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
}
int count = 1;
for (WisdomExamCommentsDto examComment : examComments) {
builder.append(String.format(finalChatModel.getCycleChats().get(0), count, examComment.getName(),
examComment.getTime(), examComment.getScore(), examComment.getScoreRate(),
examComment.getRanking()));
if (count < examComments.size()) {
count++;
}
}
builder.append(String.format(finalChatModel.getEnd(), count));
break;
}
// 智育 学科评语模版
case "windomSubject": {
List<WisdomSubjectComments> subjectComments;
try {
subjectComments = JSON.parseObject(strData, new TypeReference<List<WisdomSubjectComments>>() {
});
} catch (Exception e) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
if (subjectComments.size() <= 1) {
if (subjectComments.isEmpty()) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
}
builder.append(finalChatModel.getChat());
//拼接学科数组
for (WisdomSubjectComments comments : subjectComments) {
builder.append(String.format(finalChatModel.getCycleChats().get(0), comments.subjectName));
for (WisdomSubjectComments.StuInfo stuInfo : comments.getRankings()) {
builder.append(String.format(finalChatModel.getCycleChats().get(1), stuInfo.ranking,
stuInfo.name, stuInfo.scoreRate));
}
builder.append(String.format(finalChatModel.getCycleChats().get(2),
comments.getClaasRanking().ranking, comments.getClaasRanking().scoreRate,
comments.getClaasRanking().average));
}
builder.append(String.format(finalChatModel.getEnd(), subjectComments.size()));
break;
}
//艺术 考核指标纬度评语
case "artDimensions":
builder.append("请按照以下格式回复:\n");
builder.append("1. 艺术作品:\n");
builder.append("2. 艺术作品说明:\n");
builder.append("3. 艺术作品示例:\n");
builder.append("4. 艺术作品示例说明:\n");
break;
case "sport":
builder.append("请按照以下格式回复:\n");
builder.append("1. 运动作品:\n");
builder.append("2. 运动作品说明:\n");
builder.append("3. 运动作品示例:\n");
builder.append("4. 运动作品示例说明:\n");
break;
default:
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "评语类型异常");
}
if (chatCommentsDto.getSize() > 0) {
builder.append("字数限制在:")
.append(chatCommentsDto.getSize())
.append("字左右");
} else {
builder.append("字数限制在200字左右");
}
return builder.toString();
}
}