package cn.teammodel.service.impl;

import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import cn.teammodel.common.ErrorCode;
import cn.teammodel.common.PK;
import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.model.entity.ai.ChatApp;
import cn.teammodel.repository.ChatAppRepository;
import cn.teammodel.repository.ChatSessionRepository;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatMessageService;
import cn.teammodel.utils.RepositoryUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.annotation.Resource;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

/**
 * Chat completion service that streams model output to the client over SSE.
 *
 * <p>Two modes are supported, selected by whether the request carries an app id:
 * <ul>
 *   <li>app ("mask") mode - a one-shot completion seeded with the app's prompt;</li>
 *   <li>session mode - a completion that includes the session's previous messages
 *       and appends the new exchange to the session once the model has finished.</li>
 * </ul>
 *
 * @author winter
 * @create 2023-12-18 15:20
 */
@Service
@Slf4j
public class ChatMessageServiceImpl implements ChatMessageService {

    @Resource
    private SparkGptClient sparkGptClient;
    @Resource
    private ChatSessionRepository chatSessionRepository;
    @Resource
    private ChatAppRepository chatAppRepository;

    /**
     * Starts a streaming chat completion for the given request.
     *
     * @param chatCompletionReqDto request carrying the user text and either an app id or a session id
     * @return the emitter on which the completion is streamed
     */
    @Override
    public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) {
        // Only the iFlytek Spark model is used for now.
        String appId = chatCompletionReqDto.getAppId();
        if (StringUtils.isEmpty(appId)) {
            return completionBySession(chatCompletionReqDto);
        }
        return completionByApp(chatCompletionReqDto, false);
    }

    /**
     * App ("mask") mode. Chat history is not stored for now.
     *
     * @param chatCompletionReqDto request carrying the app id and the user text
     * @param justApi              when {@code false}, the completion callback is where message
     *                             persistence is meant to happen (not implemented yet)
     * @return the emitter on which the completion is streamed
     * @throws ServiceException if the app belongs to another school and is not public
     */
    private SseEmitter completionByApp(ChatCompletionReqDto chatCompletionReqDto, boolean justApi) {
        String appId = chatCompletionReqDto.getAppId();
        String userPrompt = chatCompletionReqDto.getText();
        User user = SecurityUtil.getLoginUser();
        String userId = user.getId();
        String schoolId = user.getSchoolId();

        // Look up the app to obtain its prompt; the reply is generated from the
        // app prompt plus the user prompt and streamed straight back.
        ChatApp chatApp = RepositoryUtil.findOne(chatAppRepository.findByAppId(appId), "该应用不存在");

        // Check that this user may use the app: it must belong to the user's school or be public.
        // A user without a school id can only reach public apps (previously this threw an NPE).
        String appSchoolId = chatApp.getSchoolId();
        boolean sameSchool = schoolId != null && schoolId.equals(appSchoolId);
        boolean publicApp = "public".equals(appSchoolId);
        if (!sameSchool && !publicApp) {
            throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "无权使用该应用");
        }
        String appPrompt = chatApp.getPrompt();

        SseEmitter sseEmitter = new SseEmitter(-1L);
        SparkGptStreamListener listener = newStreamListener(sseEmitter);
        // Callback for a finished completion.
        listener.setOnComplete((s) -> {
            log.info("callback: ws complete event emmit");
            SseHelper.send(sseEmitter, "[DONE]");
            // Post-completion handling:
            if (!justApi) {
                // TODO: save the message record and update the cache.
            }
        });

        List<SparkChatRequestParam.Message> messageList = new ArrayList<>(2);
        messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt));
        messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt));

        SparkChatRequestParam requestParam = SparkChatRequestParam
                .builder()
                .uid(userId)
                .chatId(appId)
                .messageList(messageList)
                .build();
        sparkGptClient.streamChatCompletion(requestParam, listener);
        return sseEmitter;
    }

    /**
     * Session mode: the completion is sent together with the session's previous messages,
     * and the new exchange is appended to the session once the model has finished.
     *
     * @param chatCompletionReqDto request carrying the session id and the user text
     * @return the emitter on which the completion is streamed
     * @throws ServiceException if the session does not belong to the logged-in user
     */
    private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto) {
        // Use the authenticated user. A hard-coded debug user id used to live here, which made
        // the ownership check below compare against a fixed user instead of the caller.
        User user = SecurityUtil.getLoginUser();
        String userId = user.getId();
        String userPrompt = chatCompletionReqDto.getText();
        String sessionId = chatCompletionReqDto.getSessionId();

        ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
        // Deliberately the same message as "not found", so foreign sessions are not revealed.
        if (userId == null || !userId.equals(session.getUserId())) {
            throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
        }

        SseEmitter sseEmitter = new SseEmitter(-1L);
        SparkGptStreamListener listener = newStreamListener(sseEmitter);
        // Callback for a finished completion.
        listener.setOnComplete((s) -> {
            log.info("callback: ws complete event emmit");
            SseHelper.send(sseEmitter, "[DONE]");
            // Post-completion handling: save the message record and update the cache.
            ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
            HistoryCache.updateContext(sessionId, message);
            // Patch the stored session: bump the update time and append to the history array.
            CosmosPatchOperations options = CosmosPatchOperations.create()
                    .replace("/updateTime", Instant.now().toEpochMilli())
                    .add("/history/-", message);
            chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
        });

        List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
        SparkChatRequestParam requestParam = SparkChatRequestParam
                .builder()
                .uid(userId)
                .chatId(sessionId)
                .messageList(messageList)
                .build();
        sparkGptClient.streamChatCompletion(requestParam, listener);
        return sseEmitter;
    }

    /**
     * Creates a stream listener bound to the emitter with the open and error callbacks
     * shared by both completion modes; callers add their own completion callback.
     */
    private SparkGptStreamListener newStreamListener(SseEmitter sseEmitter) {
        SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
        // Open callback.
        listener.setOnOpen((s) -> {
            // TODO: sensitive-word check and billing (reducePoints) - or do both in the completion callback?
            log.info("callback: ws open event emmit");
        });
        // Error callback.
        listener.setOnError((s) -> {
            log.error("callback: ws error, info: " + s);
            // TODO: refund points.
        });
        return listener;
    }

    /**
     * Builds the message list for a session-mode request: the stored history followed by
     * the new user prompt.
     *
     * <p>History is read from {@link HistoryCache} first; when nothing is cached it is
     * loaded from the repository and, if non-empty, put into the cache.
     *
     * @param sessionId id of the session whose history is used
     * @param prompt    the new user prompt, appended last
     * @return the messages to send; always ends with the user prompt
     */
    List<SparkChatRequestParam.Message> fetchContext(String sessionId, String prompt) {
        List<ChatSession.Message> context = HistoryCache.getContext(sessionId);
        // Not cached yet: pull from the database.
        if (ObjectUtils.isEmpty(context)) {
            context = chatSessionRepository.findLatestMessage(sessionId);
            if (ObjectUtils.isNotEmpty(context)) {
                HistoryCache.putContext(sessionId, context);
            }
        }

        // Neither the cache nor the database had history: send just the prompt
        // instead of failing with a NullPointerException.
        int historySize = context == null ? 0 : context.size();
        List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>(historySize * 2 + 1);
        if (context != null) {
            // Convert each stored exchange into a user message followed by an assistant message.
            for (ChatSession.Message item : context) {
                paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText()));
                paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText()));
            }
        }
        paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt));
        return paramMessages;
    }
}