package cn.teammodel.service.impl; import cn.teammodel.ai.SparkGptClient; import cn.teammodel.ai.SseHelper; import cn.teammodel.ai.domain.SparkChatRequestParam; import cn.teammodel.ai.listener.SparkGptStreamListener; import cn.teammodel.model.dto.ai.ChatCompletionReqDto; import cn.teammodel.model.entity.User; import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.service.ChatMessageService; import com.google.common.collect.Lists; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import java.util.List; /** * @author winter * @create 2023-12-18 15:20 */ @Service @Slf4j public class ChatMessageServiceImpl implements ChatMessageService { @Resource private SparkGptClient sparkGptClient; @Override public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) { // 目前仅使用讯飞星火大模型 User user = SecurityUtil.getLoginUser(); String userId = user.getId(); String text = chatCompletionReqDto.getText(); SseEmitter sseEmitter = new SseEmitter(-1L); SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); // open 回调 listener.setOnOpen((s) -> { // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) log.info("callback: ws open event emmit"); }); // 对话完成的回调 listener.setOnComplete((s) -> { log.info("callback: ws complete event emmit"); SseHelper.send(sseEmitter, "[DONE]"); // 处理完成后的事件: 保存消息记录 }); // 错误的回调 listener.setOnError((s) -> { log.error("callback: ws error" ); // 返还积分 }); // todo: 拉取对话上下文 List messageList = Lists.newArrayList(); messageList.add(SparkChatRequestParam.Message.ofUser(text)); // todo: sessionId SparkChatRequestParam requestParam = SparkChatRequestParam .builder() .uid(userId) .chatId("123") .messageList(messageList) .build(); sparkGptClient.streamChatCompletion(requestParam, listener); return sseEmitter; } }