feat: 新增对话历史上下文缓存

11111
winter 1 year ago
parent fa03927579
commit fede04c70d

@ -1,6 +1,7 @@
package cn.teammodel.ai;
import cn.hutool.json.JSONUtil;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import lombok.Data;
@ -38,10 +39,13 @@ public class SparkGptClient implements InitializingBean {
*
*/
public void init() {
// 初始化缓存
HistoryCache.init(sparkGptProperties.getCache_timeout(), sparkGptProperties.getCache_context());
// 初始化 authUrl
authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret());
this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://");
log.info("[SPARK CHAT] 鉴权 url: {}", this.authUrl);
log.info("[SPARK CHAT] 鉴权 endpoint : {}", this.authUrl);
// 初始化 okHttpClient
this.okHttpClient = new OkHttpClient()
.newBuilder()
.connectTimeout(90, TimeUnit.SECONDS)
@ -59,10 +63,10 @@ public class SparkGptClient implements InitializingBean {
Request request = new Request.Builder().url(authUrl).build();
// 设置请求参数
listener.setRequestJson(param.toJsonParams());
log.info("请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty());
log.info("[SPARK CHAT] 请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty());
okHttpClient.newWebSocket(request, listener);
} catch (Exception e) {
log.error("Spark AI 请求异常: {}", e.getMessage());
log.error("[SPARK CHAT] Spark AI 请求异常: {}", e.getMessage());
e.printStackTrace();
}
}
@ -90,7 +94,7 @@ public class SparkGptClient implements InitializingBean {
SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
mac.init(spec);
} catch (Exception e) {
log.error("生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret);
log.error("[SPARK CHAT] 生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret);
throw new RuntimeException(e);
}

@ -17,4 +17,12 @@ public class SparkGptProperties {
private String appId;
private String apiKey;
private String apiSecret;
/**
*
*/
private Long cache_timeout;
/**
*
*/
private Integer cache_context;
}

@ -0,0 +1,62 @@
package cn.teammodel.ai.cache;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.collection.ListUtil;
import cn.teammodel.model.entity.ai.ChatSession.Message;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import java.util.List;
/**
*
* @author winter
* @create 2023-12-20 11:02
*/
@Slf4j
@SuppressWarnings("unchecked")
public class HistoryCache {
private static TimedCache<Object, Object> HISTORY;
private static Integer contextSize = 3;
/**
*
*/
public static void init(Long timeout, Integer contextNum) {
contextSize = contextNum;
HISTORY = CacheUtil.newTimedCache(timeout);
// 一分钟清理一次
HISTORY.schedulePrune(60 * 1000);
}
public static List<Message> getContext(String sessionId) {
return (List<Message>) HISTORY.get(sessionId);
}
public static void putContext(String sessionId, List<Message> context) {
HISTORY.put(sessionId, context);
}
public static void removeContext(String sessionId) {
HISTORY.remove(sessionId);}
/**
* , contextSize
*/
public static void updateContext(String sessionId, Message message) {
List<Message> messages = (List<Message>)HISTORY.get(sessionId);
if (ObjectUtils.isEmpty(messages)) {
List<Message> context = ListUtil.of(message);
HISTORY.put(sessionId, context);
} else if (messages.size() >= contextSize) {
// 队列
messages.remove(0);
messages.add(message);
} else {
messages.add(message);
}
}
}

@ -34,7 +34,9 @@ public class SparkChatRequestParam {
//从k个候选中随机选择⼀个⾮等概率
@Default
private Integer top_k = 4;
//用于关联用户会话
/**
* (sessionId)
*/
private String chatId;
private List<Message> messageList;

@ -18,4 +18,7 @@ public interface ChatSessionRepository extends CosmosRepository<ChatSession, Str
@Query("select c.id, c.code, c.title, c.userId, c.createTime from c where c.code = 'ChatSession' and c.userId = @userId")
List<ChatSession> findByUserId(String userId);
@Query("SELECT value ARRAY_SLICE(c.history, -3) FROM c where c.id = @sessionId and c.code = 'ChatSession'")
List<ChatSession.Message> findLatestMessage(String sessionId);
}

@ -6,7 +6,7 @@ import javax.validation.constraints.NotBlank;
@Data
public class ChatCompletionReqDto {
private Long sessionId;
private String sessionId;
/**
*
*/

@ -45,19 +45,11 @@ public class ChatSession extends BaseItem {
private Integer cost;
private Long createTime;
public static Message ofUserText(String userText) {
public static Message of(String userText, String gptText) {
Message message = new Message();
message.setId(UUID.randomUUID().toString());
message.setCost(0);
message.setUserText(userText);
message.setCreateTime(Instant.now().toEpochMilli());
return message;
}
public static Message ofGptText(String gptText) {
Message message = new Message();
message.setId(UUID.randomUUID().toString());
message.setCost(0);
message.setGptText(gptText);
message.setCreateTime(Instant.now().toEpochMilli());
return message;

@ -2,18 +2,27 @@ package cn.teammodel.service.impl;
import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import cn.teammodel.common.ErrorCode;
import cn.teammodel.common.PK;
import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.dao.ChatSessionRepository;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatMessageService;
import com.google.common.collect.Lists;
import cn.teammodel.utils.RepositoryUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.List;
/**
@ -25,16 +34,25 @@ import java.util.List;
public class ChatMessageServiceImpl implements ChatMessageService {
@Resource
private SparkGptClient sparkGptClient;
@Resource
private ChatSessionRepository chatSessionRepository;
@Override
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) {
// 目前仅使用讯飞星火大模型
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
String text = chatCompletionReqDto.getText();
String userPrompt = chatCompletionReqDto.getText();
String sessionId = chatCompletionReqDto.getSessionId();
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "该会话不存在");
}
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
@ -44,24 +62,46 @@ public class ChatMessageServiceImpl implements ChatMessageService {
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create().add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error" );
// 返还积分
});
// todo: 拉取对话上下文
List<SparkChatRequestParam.Message> messageList = Lists.newArrayList();
messageList.add(SparkChatRequestParam.Message.ofUser(text));
// todo: sessionId
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId("123")
.chatId(sessionId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
List<SparkChatRequestParam.Message> fetchContext(String sessionId, String prompt) {
List<ChatSession.Message> context = HistoryCache.getContext(sessionId);
List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>();
// 暂未缓存,从数据库拉取
if (ObjectUtils.isEmpty(context)) {
context = chatSessionRepository.findLatestMessage(sessionId);
if (ObjectUtils.isNotEmpty(context)) {
HistoryCache.putContext(sessionId, context);
// convert DB Message to Spark Message
context.forEach(item -> {
paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText()));
paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText()));
});
}
}
paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt));
return paramMessages;
}
}

@ -38,7 +38,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
// 初始化欢迎语
Message message = Message.ofGptText("你好 " + user.getName() + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
Message message = Message.of("", "你好" + user.getName() + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
List<Message> history = Collections.singletonList(message);
ChatSession chatSession = new ChatSession();
chatSession.setId(UUID.randomUUID().toString());

@ -331,7 +331,7 @@ public class EvaluationServiceImpl implements EvaluationService {
appraiseRecordRepository.save(record);
} else {
CosmosPatchOperations operations = CosmosPatchOperations.create();
operations.add("/nodes/0", item);
operations.add("/nodes/-", item);
// 表扬 (待改进不会减少表扬数)
if (appraiseTreeNode.isPraise()) {
operations.increment("/praiseCount", 1);

@ -30,6 +30,8 @@ spark:
appId: c49d1e24
apiKey: 6c586e7dd1721ed1bb19bdb573b4ad34
apiSecret: MDU1MTU1Nzg4MDg2ZTJjZWU3MmI4ZGU1
cache_timeout: 1800000 # 30min
cache_context: 3
jwt:

@ -219,7 +219,18 @@ class TeamModelExtensionApplicationTests {
@Test
public void testSelectChatSession() {
System.out.println(chatSessionRepository.findByUserId("1595321354"));
// System.out.println(chatSessionRepository.findByUserId("1595321354"));
// insert message
// ChatSession.Message message = new ChatSession.Message();
// message.setId("0");
// message.setUserText("aaa");
// message.setGptText("bbb");
// message.setCost(0);
// message.setCreateTime(Instant.now().toEpochMilli());
// CosmosPatchOperations options = CosmosPatchOperations.create().add("/history/-", message);
// System.out.println(chatSessionRepository.save("111e90e5-6afd-413b-ae0f-646d957aedf8", PK.of(PK.CHAT_SESSION), ChatSession.class, options));
System.out.println(chatSessionRepository.findLatestMessage("111e90e5-6afd-413b-ae0f-646d957aedf8"));
}

Loading…
Cancel
Save