|
|
|
package cn.teammodel.service.impl;
|
|
|
|
|
|
|
|
import cn.hutool.core.lang.UUID;
|
|
|
|
import cn.teammodel.ai.SparkGptClient;
|
|
|
|
import cn.teammodel.ai.SseHelper;
|
|
|
|
import cn.teammodel.ai.cache.HistoryCache;
|
|
|
|
import cn.teammodel.ai.domain.SparkChatRequestParam;
|
|
|
|
import cn.teammodel.ai.listener.SparkGptStreamListener;
|
|
|
|
import cn.teammodel.common.ErrorCode;
|
|
|
|
import cn.teammodel.common.PK;
|
|
|
|
import cn.teammodel.config.exception.ServiceException;
|
|
|
|
import cn.teammodel.model.entity.ai.ChatApp;
|
|
|
|
import cn.teammodel.repository.ChatAppRepository;
|
|
|
|
import cn.teammodel.repository.ChatSessionRepository;
|
|
|
|
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
|
|
|
|
import cn.teammodel.model.entity.User;
|
|
|
|
import cn.teammodel.model.entity.ai.ChatSession;
|
|
|
|
import cn.teammodel.security.utils.SecurityUtil;
|
|
|
|
import cn.teammodel.service.ChatMessageService;
|
|
|
|
import cn.teammodel.utils.RepositoryUtil;
|
|
|
|
import com.azure.cosmos.models.CosmosPatchOperations;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
import org.apache.commons.lang3.ObjectUtils;
|
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
|
|
|
|
|
import javax.annotation.Resource;
|
|
|
|
import java.time.Instant;
|
|
|
|
import java.util.ArrayList;
|
|
|
|
import java.util.Collections;
|
|
|
|
import java.util.List;
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @author winter
|
|
|
|
* @create 2023-12-18 15:20
|
|
|
|
*/
|
|
|
|
@Service
|
|
|
|
@Slf4j
|
|
|
|
public class ChatMessageServiceImpl implements ChatMessageService {
|
|
|
|
@Resource
|
|
|
|
private SparkGptClient sparkGptClient;
|
|
|
|
@Resource
|
|
|
|
private ChatSessionRepository chatSessionRepository;
|
|
|
|
@Resource
|
|
|
|
private ChatAppRepository chatAppRepository;
|
|
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) {
|
|
|
|
// 目前仅使用讯飞星火大模型
|
|
|
|
String appId = chatCompletionReqDto.getAppId();
|
|
|
|
SseEmitter sseEmitter;
|
|
|
|
if (StringUtils.isEmpty(appId)) {
|
|
|
|
sseEmitter = completionBySession(chatCompletionReqDto, userId);
|
|
|
|
} else {
|
|
|
|
sseEmitter = completionByApp(chatCompletionReqDto, false);
|
|
|
|
}
|
|
|
|
return sseEmitter;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 评语调用 聊天模型
|
|
|
|
* @param chatCompletionReqDto
|
|
|
|
* @param userId
|
|
|
|
* @return
|
|
|
|
*/
|
|
|
|
@Override
|
|
|
|
public SseEmitter chatComments(ChatCompletionReqDto chatCompletionReqDto, String userId,String userName) {
|
|
|
|
// 目前仅使用讯飞星火大模型
|
|
|
|
String appId = chatCompletionReqDto.getAppId();
|
|
|
|
SseEmitter sseEmitter;
|
|
|
|
if (StringUtils.isEmpty(appId)) {
|
|
|
|
sseEmitter = commentsBySession(chatCompletionReqDto, userId,userName);
|
|
|
|
} else {
|
|
|
|
sseEmitter = completionByApp(chatCompletionReqDto, false);
|
|
|
|
}
|
|
|
|
return sseEmitter;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 面具模式(暂时不存储聊天记录)
|
|
|
|
*/
|
|
|
|
private SseEmitter completionByApp(ChatCompletionReqDto chatCompletionReqDto, boolean justApi) {
|
|
|
|
String appId = chatCompletionReqDto.getAppId();
|
|
|
|
String userPrompt = chatCompletionReqDto.getText();
|
|
|
|
User user = SecurityUtil.getLoginUser();
|
|
|
|
String userId = user.getId();
|
|
|
|
String schoolId = user.getSchoolId();
|
|
|
|
|
|
|
|
// 查询 appId 获取 prompt
|
|
|
|
// 通过 prompt 和 userprompt 生成结果
|
|
|
|
// 直接返回
|
|
|
|
ChatApp chatApp = RepositoryUtil.findOne(chatAppRepository.findByAppId(appId), "该应用不存在");
|
|
|
|
// 检验 app 是否可以被该用户使用
|
|
|
|
if (!schoolId.equals(chatApp.getSchoolId()) && !"public".equals(chatApp.getSchoolId())) {
|
|
|
|
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "无权使用该应用");
|
|
|
|
}
|
|
|
|
|
|
|
|
String appPrompt = chatApp.getPrompt();
|
|
|
|
SseEmitter sseEmitter = new SseEmitter(-1L);
|
|
|
|
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
|
|
|
|
// open 回调
|
|
|
|
listener.setOnOpen((s) -> {
|
|
|
|
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
|
|
|
|
log.info("callback: ws open event emmit");
|
|
|
|
});
|
|
|
|
// 对话完成的回调
|
|
|
|
listener.setOnComplete((s) -> {
|
|
|
|
log.info("callback: ws complete event emmit");
|
|
|
|
SseHelper.send(sseEmitter, "[DONE]");
|
|
|
|
// 处理完成后的事件:
|
|
|
|
if (!justApi) {
|
|
|
|
// 保存消息记录, 缓存更改
|
|
|
|
}
|
|
|
|
});
|
|
|
|
// 错误的回调
|
|
|
|
listener.setOnError((s) -> {
|
|
|
|
log.error("callback: ws error, info: " + s);
|
|
|
|
// 返还积分
|
|
|
|
});
|
|
|
|
List<SparkChatRequestParam.Message> messageList = new ArrayList<>();
|
|
|
|
messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt));
|
|
|
|
messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt));
|
|
|
|
SparkChatRequestParam requestParam = SparkChatRequestParam
|
|
|
|
.builder()
|
|
|
|
.uid(userId)
|
|
|
|
.chatId(appId)
|
|
|
|
.messageList(messageList)
|
|
|
|
.build();
|
|
|
|
sparkGptClient.streamChatCompletion(requestParam, listener);
|
|
|
|
return sseEmitter;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 会话模式
|
|
|
|
*/
|
|
|
|
private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto, String userId) {
|
|
|
|
String userPrompt = chatCompletionReqDto.getText();
|
|
|
|
String sessionId = chatCompletionReqDto.getSessionId();
|
|
|
|
|
|
|
|
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
|
|
|
|
if (!session.getUserId().equals(userId)) {
|
|
|
|
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
|
|
|
|
}
|
|
|
|
|
|
|
|
SseEmitter sseEmitter = new SseEmitter(-1L);
|
|
|
|
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
|
|
|
|
// open 回调
|
|
|
|
listener.setOnOpen((s) -> {
|
|
|
|
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
|
|
|
|
log.info("callback: ws open event emmit");
|
|
|
|
});
|
|
|
|
// 对话完成的回调
|
|
|
|
listener.setOnComplete((s) -> {
|
|
|
|
log.info("callback: ws complete event emmit");
|
|
|
|
SseHelper.send(sseEmitter, "[DONE]");
|
|
|
|
// 处理完成后的事件: 保存消息记录, 缓存更改
|
|
|
|
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
|
|
|
|
HistoryCache.updateContext(sessionId, message);
|
|
|
|
CosmosPatchOperations options = CosmosPatchOperations.create()
|
|
|
|
.replace("/updateTime", Instant.now().toEpochMilli())
|
|
|
|
.add("/history/-", message);
|
|
|
|
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
|
|
|
|
});
|
|
|
|
// 错误的回调
|
|
|
|
listener.setOnError((s) -> {
|
|
|
|
log.error("callback: ws error, info: " + s);
|
|
|
|
// 返还积分
|
|
|
|
});
|
|
|
|
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
|
|
|
|
SparkChatRequestParam requestParam = SparkChatRequestParam
|
|
|
|
.builder()
|
|
|
|
.uid(userId)
|
|
|
|
.chatId(sessionId)
|
|
|
|
.messageList(messageList)
|
|
|
|
.build();
|
|
|
|
sparkGptClient.streamChatCompletion(requestParam, listener);
|
|
|
|
return sseEmitter;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 评语 会话模式
|
|
|
|
*/
|
|
|
|
private SseEmitter commentsBySession(ChatCompletionReqDto chatCompletionReqDto, String userId,String userName) {
|
|
|
|
String userPrompt = chatCompletionReqDto.getText();
|
|
|
|
String sessionId = chatCompletionReqDto.getSessionId();
|
|
|
|
|
|
|
|
ChatSession session = null;
|
|
|
|
List<ChatSession> sessions = chatSessionRepository.findCommentsById(userId);
|
|
|
|
if (sessions.size() == 0) {
|
|
|
|
// 初始化欢迎语
|
|
|
|
ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
|
|
|
|
List<ChatSession.Message> history = Collections.singletonList(message);
|
|
|
|
session = new ChatSession();
|
|
|
|
session.setId(userId);
|
|
|
|
session.setCode(PK.CHAT_SESSION);
|
|
|
|
session.setTitle("评语");
|
|
|
|
session.setUserId(userId);
|
|
|
|
session.setCreateTime(Instant.now().toEpochMilli());
|
|
|
|
session.setUpdateTime(Instant.now().toEpochMilli());
|
|
|
|
session.setHistory(history);
|
|
|
|
chatSessionRepository.save(session);
|
|
|
|
}else {
|
|
|
|
session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(userId), "该会话不存在");
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
|
|
|
|
if (!session.getUserId().equals(userId)) {
|
|
|
|
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
|
|
|
|
}*/
|
|
|
|
|
|
|
|
SseEmitter sseEmitter = new SseEmitter(-1L);
|
|
|
|
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
|
|
|
|
// open 回调
|
|
|
|
listener.setOnOpen((s) -> {
|
|
|
|
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
|
|
|
|
log.info("callback: ws open event emmit");
|
|
|
|
});
|
|
|
|
// 对话完成的回调
|
|
|
|
listener.setOnComplete((s) -> {
|
|
|
|
log.info("callback: ws complete event emmit");
|
|
|
|
SseHelper.send(sseEmitter, "[DONE]");
|
|
|
|
// 处理完成后的事件: 保存消息记录, 缓存更改
|
|
|
|
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
|
|
|
|
HistoryCache.updateContext(sessionId, message);
|
|
|
|
CosmosPatchOperations options = CosmosPatchOperations.create()
|
|
|
|
.replace("/updateTime", Instant.now().toEpochMilli())
|
|
|
|
.add("/history/-", message);
|
|
|
|
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
|
|
|
|
});
|
|
|
|
// 错误的回调
|
|
|
|
listener.setOnError((s) -> {
|
|
|
|
log.error("callback: ws error, info: " + s);
|
|
|
|
// 返还积分
|
|
|
|
});
|
|
|
|
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
|
|
|
|
SparkChatRequestParam requestParam = SparkChatRequestParam
|
|
|
|
.builder()
|
|
|
|
.uid(userId)
|
|
|
|
.chatId(sessionId)
|
|
|
|
.messageList(messageList)
|
|
|
|
.build();
|
|
|
|
sparkGptClient.streamChatCompletion(requestParam, listener);
|
|
|
|
return sseEmitter;
|
|
|
|
}
|
|
|
|
|
|
|
|
List<SparkChatRequestParam.Message> fetchContext(String sessionId, String prompt) {
|
|
|
|
List<ChatSession.Message> context = HistoryCache.getContext(sessionId);
|
|
|
|
List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>();
|
|
|
|
// 暂未缓存,从数据库拉取
|
|
|
|
if (ObjectUtils.isEmpty(context)) {
|
|
|
|
context = chatSessionRepository.findLatestMessage(sessionId);
|
|
|
|
|
|
|
|
if (ObjectUtils.isNotEmpty(context)) {
|
|
|
|
HistoryCache.putContext(sessionId, context);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// convert DB Message to Spark Message
|
|
|
|
context.forEach(item -> {
|
|
|
|
paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText()));
|
|
|
|
paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText()));
|
|
|
|
});
|
|
|
|
paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt));
|
|
|
|
return paramMessages;
|
|
|
|
}
|
|
|
|
}
|