|
|
|
package cn.teammodel.service.impl;
|
|
|
|
|
|
|
|
import cn.teammodel.ai.SparkGptClient;
|
|
|
|
import cn.teammodel.ai.SseHelper;
|
|
|
|
import cn.teammodel.ai.cache.HistoryCache;
|
|
|
|
import cn.teammodel.ai.domain.SparkChatRequestParam;
|
|
|
|
import cn.teammodel.ai.listener.SparkGptStreamListener;
|
|
|
|
import cn.teammodel.common.ErrorCode;
|
|
|
|
import cn.teammodel.common.PK;
|
|
|
|
import cn.teammodel.config.exception.ServiceException;
|
|
|
|
import cn.teammodel.dao.ChatSessionRepository;
|
|
|
|
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
|
|
|
|
import cn.teammodel.model.entity.User;
|
|
|
|
import cn.teammodel.model.entity.ai.ChatSession;
|
|
|
|
import cn.teammodel.security.utils.SecurityUtil;
|
|
|
|
import cn.teammodel.service.ChatMessageService;
|
|
|
|
import cn.teammodel.utils.RepositoryUtil;
|
|
|
|
import com.azure.cosmos.models.CosmosPatchOperations;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
import org.apache.commons.lang3.ObjectUtils;
|
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
|
|
|
|
|
import javax.annotation.Resource;
|
|
|
|
import java.util.ArrayList;
|
|
|
|
import java.util.List;
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @author winter
|
|
|
|
* @create 2023-12-18 15:20
|
|
|
|
*/
|
|
|
|
@Service
|
|
|
|
@Slf4j
|
|
|
|
public class ChatMessageServiceImpl implements ChatMessageService {
|
|
|
|
@Resource
|
|
|
|
private SparkGptClient sparkGptClient;
|
|
|
|
@Resource
|
|
|
|
private ChatSessionRepository chatSessionRepository;
|
|
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) {
|
|
|
|
// 目前仅使用讯飞星火大模型
|
|
|
|
User user = SecurityUtil.getLoginUser();
|
|
|
|
String userId = user.getId();
|
|
|
|
String userPrompt = chatCompletionReqDto.getText();
|
|
|
|
String sessionId = chatCompletionReqDto.getSessionId();
|
|
|
|
|
|
|
|
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
|
|
|
|
if (!session.getUserId().equals(userId)) {
|
|
|
|
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "该会话不存在");
|
|
|
|
}
|
|
|
|
|
|
|
|
SseEmitter sseEmitter = new SseEmitter(-1L);
|
|
|
|
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
|
|
|
|
// open 回调
|
|
|
|
listener.setOnOpen((s) -> {
|
|
|
|
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
|
|
|
|
log.info("callback: ws open event emmit");
|
|
|
|
});
|
|
|
|
// 对话完成的回调
|
|
|
|
listener.setOnComplete((s) -> {
|
|
|
|
log.info("callback: ws complete event emmit");
|
|
|
|
SseHelper.send(sseEmitter, "[DONE]");
|
|
|
|
// 处理完成后的事件: 保存消息记录, 缓存更改
|
|
|
|
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
|
|
|
|
HistoryCache.updateContext(sessionId, message);
|
|
|
|
CosmosPatchOperations options = CosmosPatchOperations.create().add("/history/-", message);
|
|
|
|
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
|
|
|
|
});
|
|
|
|
// 错误的回调
|
|
|
|
listener.setOnError((s) -> {
|
|
|
|
log.error("callback: ws error" );
|
|
|
|
// 返还积分
|
|
|
|
});
|
|
|
|
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
|
|
|
|
SparkChatRequestParam requestParam = SparkChatRequestParam
|
|
|
|
.builder()
|
|
|
|
.uid(userId)
|
|
|
|
.chatId(sessionId)
|
|
|
|
.messageList(messageList)
|
|
|
|
.build();
|
|
|
|
sparkGptClient.streamChatCompletion(requestParam, listener);
|
|
|
|
return sseEmitter;
|
|
|
|
}
|
|
|
|
|
|
|
|
List<SparkChatRequestParam.Message> fetchContext(String sessionId, String prompt) {
|
|
|
|
List<ChatSession.Message> context = HistoryCache.getContext(sessionId);
|
|
|
|
List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>();
|
|
|
|
// 暂未缓存,从数据库拉取
|
|
|
|
if (ObjectUtils.isEmpty(context)) {
|
|
|
|
context = chatSessionRepository.findLatestMessage(sessionId);
|
|
|
|
|
|
|
|
if (ObjectUtils.isNotEmpty(context)) {
|
|
|
|
HistoryCache.putContext(sessionId, context);
|
|
|
|
// convert DB Message to Spark Message
|
|
|
|
context.forEach(item -> {
|
|
|
|
paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText()));
|
|
|
|
paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText()));
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt));
|
|
|
|
return paramMessages;
|
|
|
|
}
|
|
|
|
}
|