You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
IESExtension/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java

502 lines
22 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package cn.teammodel.service.impl;
import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import cn.teammodel.common.ErrorCode;
import cn.teammodel.common.PK;
import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.model.dto.ai.*;
import cn.teammodel.model.dto.ai.comment.ChatCommentsDto;
import cn.teammodel.model.dto.ai.comment.WisdomCommentsDto;
import cn.teammodel.model.dto.ai.comment.WisdomExamCommentsDto;
import cn.teammodel.model.dto.ai.comment.WisdomSubjectComments;
import cn.teammodel.model.entity.ai.ChatApp;
import cn.teammodel.repository.ChatAppRepository;
import cn.teammodel.repository.ChatSessionRepository;
import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatMessageService;
import cn.teammodel.utils.FileUtil;
import cn.teammodel.utils.RepositoryUtil;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.TypeReference;
import com.azure.cosmos.models.CosmosPatchOperations;
import com.sun.xml.internal.ws.transport.http.ResourceLoader;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.*;
import java.net.URL;
import java.time.Instant;
import java.util.*;
/**
* @author winter
* @create 2023-12-18 15:20
*/
@Service
@Slf4j
public class ChatMessageServiceImpl implements ChatMessageService {
@Resource
private SparkGptClient sparkGptClient;
@Resource
private ChatSessionRepository chatSessionRepository;
@Resource
private ChatAppRepository chatAppRepository;
@Override
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) {
// 目前仅使用讯飞星火大模型
String appId = chatCompletionReqDto.getAppId();
SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) {
sseEmitter = completionBySession(chatCompletionReqDto, userId);
} else {
sseEmitter = completionByApp(chatCompletionReqDto, false);
}
return sseEmitter;
}
/**
* 评语调用 聊天模型
*
* @param chatCommentsDto
* @param userId
* @param userName
* @return
*/
@Override
public SseEmitter chatComments(ChatCommentsDto chatCommentsDto, String userId, String userName) {
// 目前仅使用讯飞星火大模型
String appId = chatCommentsDto.getAppId();
String text = commentsTemplate(chatCommentsDto);
if (!StringUtils.isEmpty(text)) {
chatCommentsDto.setText(text);
} else {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "参数错误");
}
SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) {
sseEmitter = commentsBySession(chatCommentsDto, userId, userName);
} else {
ChatCompletionReqDto chatCompletionReqDto = new ChatCompletionReqDto();
chatCompletionReqDto.setAppId(chatCommentsDto.getAppId());
chatCompletionReqDto.setSessionId(chatCommentsDto.getSessionId());
chatCompletionReqDto.setText(chatCommentsDto.getText());
sseEmitter = completionByApp(chatCompletionReqDto, false);
}
return sseEmitter;
}
/**
* 面具模式(暂时不存储聊天记录)
*/
private SseEmitter completionByApp(ChatCompletionReqDto chatCompletionReqDto, boolean justApi) {
String appId = chatCompletionReqDto.getAppId();
String userPrompt = chatCompletionReqDto.getText();
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
String schoolId = user.getSchoolId();
// 查询 appId 获取 prompt
// 通过 prompt 和 userprompt 生成结果
// 直接返回
ChatApp chatApp = RepositoryUtil.findOne(chatAppRepository.findByAppId(appId), "该应用不存在");
// 检验 app 是否可以被该用户使用
if (!schoolId.equals(chatApp.getSchoolId()) && !"public".equals(chatApp.getSchoolId())) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "无权使用该应用");
}
String appPrompt = chatApp.getPrompt();
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件:
if (!justApi) {
// 保存消息记录, 缓存更改
}
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = new ArrayList<>();
messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt));
messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt));
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(appId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
/**
* 会话模式
*/
private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto, String userId) {
String userPrompt = chatCompletionReqDto.getText();
String sessionId = chatCompletionReqDto.getSessionId();
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
}
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(sessionId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
/**
* 评语 会话模式
*/
private SseEmitter commentsBySession(ChatCommentsDto chatCommentsDto, String userId, String userName) {
String userPrompt = chatCommentsDto.getText();
Object data = chatCommentsDto.getData();
//获取会话id 看是否有sessionId 有则直接赋值 没有则赋值userId
String sessionId = chatCommentsDto.getSessionId();
ChatSession session = null;
List<ChatSession> sessions = chatSessionRepository.findCommentsById(sessionId);
if (sessions.size() == 0) {
// 初始化欢迎语
ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆," +
"你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
List<ChatSession.Message> history = Collections.singletonList(message);
session = new ChatSession();
session.setId(sessionId);
session.setCode(PK.CHAT_SESSION);
session.setTitle("评语");
session.setUserId(userId);
session.setCreateTime(Instant.now().toEpochMilli());
session.setUpdateTime(Instant.now().toEpochMilli());
session.setHistory(history);
chatSessionRepository.save(session);
} else {
session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
}
/*
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
}*/
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(userId, userPrompt);
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(userId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
List<SparkChatRequestParam.Message> fetchContext(String userId, String prompt) {
List<ChatSession.Message> context = HistoryCache.getContext(userId);
List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>();
// 暂未缓存,从数据库拉取
if (ObjectUtils.isEmpty(context)) {
context = chatSessionRepository.findLatestMessage(userId);
if (ObjectUtils.isNotEmpty(context)) {
HistoryCache.putContext(userId, context);
}
}
// convert DB Message to Spark Message
context.forEach(item -> {
paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText()));
paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText()));
});
paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt));
return paramMessages;
}
/**
* 评语调用 聊天模型
* 待优化
* @return
*/
private String commentsTemplate(ChatCommentsDto chatCommentsDto) {
StringBuilder builder = new StringBuilder();
String strData = JSON.toJSONString(chatCommentsDto.getData());
File file = new File(ClassLoader.getSystemResource("Json/ChatModel.json").getPath()); //相对路径获取文件信息
//File file = new File("src/main/resources/Json/ChatModel.json"); //绝对路径获取文件信息
List<ChatModelDto> chatModels = new ArrayList<>();
ChatModelDto chatModel = null;
//chatModel = readerMethod(file);
chatModels = readerMethod(file);
if(chatModels.size() <= 0)
{
throw new ServiceException(ErrorCode.NOT_FOUND_ERROR.getCode(), "评语模版未配置");
}
//循环查找对应的模型数据
for(ChatModelDto chatModelTemp : chatModels)
{
//判断评语类型
if(chatCommentsDto.getType().equals(chatModelTemp.getType()))
{
chatModel = chatModelTemp;
break;
}
}
if(chatModel != null)
{
//角色条件
builder.append(String.format(chatModel.getRole(),chatCommentsDto.getPeriod(),chatCommentsDto.getSubject()));
}
ChatModelDto finalChatModel = chatModel;
//模版
switch (chatCommentsDto.getType()) {
//智育 总体评语模版
case "wisdom":
{
WisdomCommentsDto wisdomComments;
//转换问题
try {
//转换方式
wisdomComments = JSON.parseObject(strData, WisdomCommentsDto.class);
} catch (Exception e) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "类型转换失败");
}
if(wisdomComments.getName() == null)
{
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
// 使用String.format方法生成最终的字符串
builder.append(String.format(
finalChatModel.getChat(),
wisdomComments.getName(),
wisdomComments.getYear(),
wisdomComments.getSemester(),
wisdomComments.getJoinAll(),
wisdomComments.getLessonMiddle(),
wisdomComments.getOnLine(),
wisdomComments.getMarking(),
wisdomComments.getLevel(),
wisdomComments.getProportion(),
wisdomComments.getName()
));
/*
//有平均分
builder.append(String.format(
finalChatModel.getChat(),
wisdomComments.getName(),
wisdomComments.getYear(),
wisdomComments.getSemester(),
wisdomComments.getJoinAll(),
wisdomComments.getLessonMiddle(),
wisdomComments.getOnLine(),
wisdomComments.getMarking(),
wisdomComments.getAverage(),
wisdomComments.getLevel(),
wisdomComments.getProportion(),
wisdomComments.getName()
));*/
wisdomComments.getDims().forEach(item -> {
builder.append(String.format(finalChatModel.getCycleChats().get(0),
item.name,
item.data[0],
item.data[1],
item.data[2],
item.data[3],
item.data[4]
));
});
builder.append(String.format(finalChatModel.getEnd(),wisdomComments.getName()));
break;
}
//智育 表现模版
case "wisdomExam":
{
List<WisdomExamCommentsDto> examComments = new ArrayList<>();
try {
examComments = JSON.parseObject(strData, new TypeReference<List<WisdomExamCommentsDto>>() {});
} catch (Exception e) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "类型转换失败");
}
if(examComments.size() <= 1){
if (examComments.isEmpty()){
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
if (examComments.get(0).name == null){
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
}
int count = 1;
for (WisdomExamCommentsDto examComment : examComments) {
builder.append(String.format(finalChatModel.getCycleChats().get(0), count, examComment.getName(), examComment.getTime(), examComment.getScore(), examComment.getScoreRate(), examComment.getBanking()));
if (count < examComments.size())
{
count++;
}
}
builder.append(String.format(finalChatModel.getEnd(),count));
break;
}
// 智育 学科评语模版
case "windomSubject":
{
List<WisdomSubjectComments> subjectComments;
try {
subjectComments = JSON.parseObject(strData, new TypeReference<List<WisdomSubjectComments>>() {});
}
catch (Exception e) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
if(subjectComments.size() <= 1){
if (subjectComments.isEmpty()){
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "请求参数异常");
}
}
builder.append(finalChatModel.getChat());
//拼接学科数组
for (WisdomSubjectComments comments : subjectComments){
builder.append(String.format(finalChatModel.getCycleChats().get(0), comments.subjectName));
for (WisdomSubjectComments.StuInfo stuInfo : comments.getRankings()){
builder.append(String.format(finalChatModel.getCycleChats().get(1), stuInfo.ranking, stuInfo.name, stuInfo.scoreRate));
}
builder.append(String.format(finalChatModel.getCycleChats().get(2), comments.getClaasRanking().ranking, comments.getClaasRanking().scoreRate,comments.getClaasRanking().average));
}
builder.append(String.format(finalChatModel.getEnd(),subjectComments.size()));
break;
}
//艺术 考核指标纬度评语
case "artDimensions":
builder.append("请按照以下格式回复:\n");
builder.append("1. 艺术作品:\n");
builder.append("2. 艺术作品说明:\n");
builder.append("3. 艺术作品示例:\n");
builder.append("4. 艺术作品示例说明:\n");
break;
case "sport":
builder.append("请按照以下格式回复:\n");
builder.append("1. 运动作品:\n");
builder.append("2. 运动作品说明:\n");
builder.append("3. 运动作品示例:\n");
builder.append("4. 运动作品示例说明:\n");
break;
default:
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "评语类型异常");
}
if (chatCommentsDto.getSize() > 0 ){
builder.append("字数限制在:")
.append(chatCommentsDto.getSize())
.append("字左右");
}
return builder.toString();
}
/**
* 读取文件信息并转换为智能对话模型数组对象
* @param file
* @return
*/
private static List<ChatModelDto> readerMethod(File file) {
//读取文件信息并返回string字符串 并改成json格式
String fileTxt = FileUtil.readFile(file);
String strData =JSON.toJSONString(fileTxt);
//获取聊天字段中的数据
Object str = JSON.parseObject(strData).get("chatModel");
String strData2 = JSON.toJSONString(str);
List<ChatModelDto> chatModelDtos = new ArrayList<>();
//转换问题
try {
//转换方式
chatModelDtos = JSON.parseObject(strData2, new TypeReference<List<ChatModelDto>>() {});
} catch (Exception e) {
throw new ServiceException(ErrorCode.OPERATION_ERROR.getCode(), "类型转换失败");
}
return chatModelDtos;
}
}