diff --git a/src/main/java/cn/teammodel/ai/SparkGptClient.java b/src/main/java/cn/teammodel/ai/SparkGptClient.java index 4616e06..af75b02 100644 --- a/src/main/java/cn/teammodel/ai/SparkGptClient.java +++ b/src/main/java/cn/teammodel/ai/SparkGptClient.java @@ -1,6 +1,7 @@ package cn.teammodel.ai; import cn.hutool.json.JSONUtil; +import cn.teammodel.ai.cache.HistoryCache; import cn.teammodel.ai.domain.SparkChatRequestParam; import cn.teammodel.ai.listener.SparkGptStreamListener; import lombok.Data; @@ -38,10 +39,13 @@ public class SparkGptClient implements InitializingBean { * 静态构造对象方法 */ public void init() { + // 初始化缓存 + HistoryCache.init(sparkGptProperties.getCache_timeout(), sparkGptProperties.getCache_context()); + // 初始化 authUrl authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret()); this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://"); - log.info("[SPARK CHAT] 鉴权 url: {}", this.authUrl); - + log.info("[SPARK CHAT] 鉴权 endpoint : {}", this.authUrl); + // 初始化 okHttpClient this.okHttpClient = new OkHttpClient() .newBuilder() .connectTimeout(90, TimeUnit.SECONDS) @@ -59,10 +63,10 @@ public class SparkGptClient implements InitializingBean { Request request = new Request.Builder().url(authUrl).build(); // 设置请求参数 listener.setRequestJson(param.toJsonParams()); - log.info("请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty()); + log.info("[SPARK CHAT] 请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty()); okHttpClient.newWebSocket(request, listener); } catch (Exception e) { - log.error("Spark AI 请求异常: {}", e.getMessage()); + log.error("[SPARK CHAT] Spark AI 请求异常: {}", e.getMessage()); e.printStackTrace(); } } @@ -90,7 +94,7 @@ public class SparkGptClient implements InitializingBean { SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256"); mac.init(spec); } catch (Exception e) { - log.error("生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret); + log.error("[SPARK CHAT] 生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret); throw new RuntimeException(e); } diff --git a/src/main/java/cn/teammodel/ai/SparkGptProperties.java b/src/main/java/cn/teammodel/ai/SparkGptProperties.java index 7d82632..e3719ed 100644 --- a/src/main/java/cn/teammodel/ai/SparkGptProperties.java +++ b/src/main/java/cn/teammodel/ai/SparkGptProperties.java @@ -17,4 +17,12 @@ public class SparkGptProperties { private String appId; private String apiKey; private String apiSecret; + /** + * 单个会话的缓存过期时间 + */ + private Long cache_timeout; + /** + * 历史上下文数 + */ + private Integer cache_context; } diff --git a/src/main/java/cn/teammodel/ai/cache/HistoryCache.java b/src/main/java/cn/teammodel/ai/cache/HistoryCache.java new file mode 100644 index 0000000..6bbe392 --- /dev/null +++ b/src/main/java/cn/teammodel/ai/cache/HistoryCache.java @@ -0,0 +1,62 @@ +package cn.teammodel.ai.cache; + +import cn.hutool.cache.CacheUtil; +import cn.hutool.cache.impl.TimedCache; +import cn.hutool.core.collection.ListUtil; +import cn.teammodel.model.entity.ai.ChatSession.Message; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ObjectUtils; + +import java.util.List; + +/** + * 聊天记录上下文的缓存 + * @author winter + * @create 2023-12-20 11:02 + */ +@Slf4j +@SuppressWarnings("unchecked") +public class HistoryCache { + private static TimedCache HISTORY; + private static Integer contextSize = 3; + + /** + * 初始化缓存 + */ + public static void init(Long timeout, Integer contextNum) { + contextSize = contextNum; + HISTORY = CacheUtil.newTimedCache(timeout); + // 一分钟清理一次 + HISTORY.schedulePrune(60 * 1000); + } + + public static List getContext(String sessionId) { + return (List) HISTORY.get(sessionId); + } + + public static void putContext(String sessionId, List context) { + HISTORY.put(sessionId, context); + } + + public static void removeContext(String sessionId) { + HISTORY.remove(sessionId);} + + /** + * 更新上下文, 保证上下文的数量在 contextSize 之内 + */ + public static void updateContext(String sessionId, Message message) { + List messages = (List)HISTORY.get(sessionId); + + if (ObjectUtils.isEmpty(messages)) { + List context = ListUtil.of(message); + HISTORY.put(sessionId, context); + } else if (messages.size() >= contextSize) { + // 队列 + messages.remove(0); + messages.add(message); + } else { + messages.add(message); + } + } + +} diff --git a/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java b/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java index 1a12831..8a450ea 100644 --- a/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java +++ b/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java @@ -34,7 +34,9 @@ public class SparkChatRequestParam { //从k个候选中随机选择⼀个(⾮等概率) @Default private Integer top_k = 4; - //用于关联用户会话 + /** + * 用于关联用户会话 (sessionId) + */ private String chatId; private List messageList; diff --git a/src/main/java/cn/teammodel/dao/ChatSessionRepository.java b/src/main/java/cn/teammodel/dao/ChatSessionRepository.java index 013e4f7..8f6f549 100644 --- a/src/main/java/cn/teammodel/dao/ChatSessionRepository.java +++ b/src/main/java/cn/teammodel/dao/ChatSessionRepository.java @@ -18,4 +18,7 @@ public interface ChatSessionRepository extends CosmosRepository findByUserId(String userId); + + @Query("SELECT value ARRAY_SLICE(c.history, -3) FROM c where c.id = @sessionId and c.code = 'ChatSession'") + List findLatestMessage(String sessionId); } diff --git a/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java b/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java index c96d4fe..e4a0ff4 100644 --- a/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java +++ b/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java @@ -6,7 +6,7 @@ import javax.validation.constraints.NotBlank; @Data public class ChatCompletionReqDto { - private Long sessionId; + private String sessionId; /** * 预设的会话面具 */ diff --git a/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java b/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java index dfb5542..fcda6be 100644 --- a/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java +++ b/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java @@ -45,19 +45,11 @@ public class ChatSession extends BaseItem { private Integer cost; private Long createTime; - public static Message ofUserText(String userText) { + public static Message of(String userText, String gptText) { Message message = new Message(); message.setId(UUID.randomUUID().toString()); message.setCost(0); message.setUserText(userText); - message.setCreateTime(Instant.now().toEpochMilli()); - return message; - } - - public static Message ofGptText(String gptText) { - Message message = new Message(); - message.setId(UUID.randomUUID().toString()); - message.setCost(0); message.setGptText(gptText); message.setCreateTime(Instant.now().toEpochMilli()); return message; diff --git a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java index db65933..ab6271d 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java @@ -2,18 +2,27 @@ package cn.teammodel.service.impl; import cn.teammodel.ai.SparkGptClient; import cn.teammodel.ai.SseHelper; +import cn.teammodel.ai.cache.HistoryCache; import cn.teammodel.ai.domain.SparkChatRequestParam; import cn.teammodel.ai.listener.SparkGptStreamListener; +import cn.teammodel.common.ErrorCode; +import cn.teammodel.common.PK; +import cn.teammodel.config.exception.ServiceException; +import cn.teammodel.dao.ChatSessionRepository; import cn.teammodel.model.dto.ai.ChatCompletionReqDto; import cn.teammodel.model.entity.User; +import cn.teammodel.model.entity.ai.ChatSession; import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.service.ChatMessageService; -import com.google.common.collect.Lists; +import cn.teammodel.utils.RepositoryUtil; +import com.azure.cosmos.models.CosmosPatchOperations; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ObjectUtils; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; +import java.util.ArrayList; import java.util.List; /** @@ -25,16 +34,25 @@ import java.util.List; public class ChatMessageServiceImpl implements ChatMessageService { @Resource private SparkGptClient sparkGptClient; + @Resource + private ChatSessionRepository chatSessionRepository; + @Override public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) { // 目前仅使用讯飞星火大模型 User user = SecurityUtil.getLoginUser(); String userId = user.getId(); - String text = chatCompletionReqDto.getText(); + String userPrompt = chatCompletionReqDto.getText(); + String sessionId = chatCompletionReqDto.getSessionId(); + + ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在"); + if (!session.getUserId().equals(userId)) { + throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "该会话不存在"); + } + SseEmitter sseEmitter = new SseEmitter(-1L); SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); - // open 回调 listener.setOnOpen((s) -> { // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) @@ -44,24 +62,46 @@ public class ChatMessageServiceImpl implements ChatMessageService { listener.setOnComplete((s) -> { log.info("callback: ws complete event emmit"); SseHelper.send(sseEmitter, "[DONE]"); - // 处理完成后的事件: 保存消息记录 + // 处理完成后的事件: 保存消息记录, 缓存更改 + ChatSession.Message message = ChatSession.Message.of(userPrompt, s); + HistoryCache.updateContext(sessionId, message); + CosmosPatchOperations options = CosmosPatchOperations.create().add("/history/-", message); + chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options); }); // 错误的回调 listener.setOnError((s) -> { log.error("callback: ws error" ); // 返还积分 }); - // todo: 拉取对话上下文 - List messageList = Lists.newArrayList(); - messageList.add(SparkChatRequestParam.Message.ofUser(text)); - // todo: sessionId + List messageList = fetchContext(sessionId, userPrompt); SparkChatRequestParam requestParam = SparkChatRequestParam .builder() .uid(userId) - .chatId("123") + .chatId(sessionId) .messageList(messageList) .build(); sparkGptClient.streamChatCompletion(requestParam, listener); return sseEmitter; } + + List fetchContext(String sessionId, String prompt) { + List context = HistoryCache.getContext(sessionId); + List paramMessages = new ArrayList<>(); + // 暂未缓存,从数据库拉取 + if (ObjectUtils.isEmpty(context)) { + context = chatSessionRepository.findLatestMessage(sessionId); + + if (ObjectUtils.isNotEmpty(context)) { + HistoryCache.putContext(sessionId, context); + // convert DB Message to Spark Message + context.forEach(item -> { + paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText())); + paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText())); + }); + } + } + + paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt)); + return paramMessages; + } } diff --git a/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java index 2f272b7..4ffd4ed 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java @@ -38,7 +38,7 @@ public class ChatSessionServiceImpl implements ChatSessionService { User user = SecurityUtil.getLoginUser(); String userId = user.getId(); // 初始化欢迎语 - Message message = Message.ofGptText("你好 " + user.getName() + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); + Message message = Message.of("", "你好" + user.getName() + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); List history = Collections.singletonList(message); ChatSession chatSession = new ChatSession(); chatSession.setId(UUID.randomUUID().toString()); diff --git a/src/main/java/cn/teammodel/service/impl/EvaluationServiceImpl.java b/src/main/java/cn/teammodel/service/impl/EvaluationServiceImpl.java index db26386..fb11eb8 100644 --- a/src/main/java/cn/teammodel/service/impl/EvaluationServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/EvaluationServiceImpl.java @@ -331,7 +331,7 @@ public class EvaluationServiceImpl implements EvaluationService { appraiseRecordRepository.save(record); } else { CosmosPatchOperations operations = CosmosPatchOperations.create(); - operations.add("/nodes/0", item); + operations.add("/nodes/-", item); // 表扬 (待改进不会减少表扬数) if (appraiseTreeNode.isPraise()) { operations.increment("/praiseCount", 1); diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 7ab4ee1..e505ad2 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -30,6 +30,8 @@ spark: appId: c49d1e24 apiKey: 6c586e7dd1721ed1bb19bdb573b4ad34 apiSecret: MDU1MTU1Nzg4MDg2ZTJjZWU3MmI4ZGU1 + cache_timeout: 1800000 # 30min + cache_context: 3 jwt: diff --git a/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java b/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java index 8d0dfb8..492bb13 100644 --- a/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java +++ b/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java @@ -219,7 +219,18 @@ class TeamModelExtensionApplicationTests { @Test public void testSelectChatSession() { - System.out.println(chatSessionRepository.findByUserId("1595321354")); +// System.out.println(chatSessionRepository.findByUserId("1595321354")); + // insert message +// ChatSession.Message message = new ChatSession.Message(); +// message.setId("0"); +// message.setUserText("aaa"); +// message.setGptText("bbb"); +// message.setCost(0); +// message.setCreateTime(Instant.now().toEpochMilli()); +// CosmosPatchOperations options = CosmosPatchOperations.create().add("/history/-", message); +// System.out.println(chatSessionRepository.save("111e90e5-6afd-413b-ae0f-646d957aedf8", PK.of(PK.CHAT_SESSION), ChatSession.class, options)); + + System.out.println(chatSessionRepository.findLatestMessage("111e90e5-6afd-413b-ae0f-646d957aedf8")); }