聊天评语设置接口

11111
PL 5 months ago
parent 523dfa47b6
commit e8dfcf7f9f

@ -10,9 +10,11 @@ import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatAppService; import cn.teammodel.service.ChatAppService;
import cn.teammodel.service.ChatMessageService; import cn.teammodel.service.ChatMessageService;
import cn.teammodel.service.ChatSessionService; import cn.teammodel.service.ChatSessionService;
import io.jsonwebtoken.Claims;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@ -147,9 +149,22 @@ public class AiController {
return R.success("删除应用成功"); return R.success("删除应用成功");
} }
@PostMapping("chat/comments")
@ApiOperation("设置评语")
public SseEmitter chatComments(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) {
/*
Authentication user0 = SecurityUtil.getAuthentication();
Object user01 = SecurityUtil.getAuthentication().getPrincipal();
TmdUserDetail user02 = (TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal();
Claims user03 = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims();
String user04 = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
*/
//String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
// 获取getClaims时为空
String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getUser().getId();
String userName = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getUser().getName();
return chatMessageService.chatComments(chatCompletionReqDto, userId, userName);
}
} }

@ -22,4 +22,7 @@ public interface ChatSessionRepository extends CosmosRepository<ChatSession, Str
@Query("SELECT value ARRAY_SLICE(c.history, -3) FROM c where c.id = @sessionId and c.code = 'ChatSession'") @Query("SELECT value ARRAY_SLICE(c.history, -3) FROM c where c.id = @sessionId and c.code = 'ChatSession'")
List<ChatSession.Message> findLatestMessage(String sessionId); List<ChatSession.Message> findLatestMessage(String sessionId);
@Query("select c.id, c.code, c.title, c.userId, c.createTime from c where c.code = 'ChatSession' and c.id = @userId")
List<ChatSession> findCommentsById(String userId);
} }

@ -12,4 +12,11 @@ public interface ChatMessageService {
* AI * AI
*/ */
SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId); SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId);
/**
* AI
* @param chatCompletionReqDto
* @return
*/
SseEmitter chatComments(ChatCompletionReqDto chatCompletionReqDto, String userId,String userName);
} }

@ -1,5 +1,6 @@
package cn.teammodel.service.impl; package cn.teammodel.service.impl;
import cn.hutool.core.lang.UUID;
import cn.teammodel.ai.SparkGptClient; import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper; import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.cache.HistoryCache; import cn.teammodel.ai.cache.HistoryCache;
@ -27,6 +28,7 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
/** /**
@ -57,6 +59,25 @@ public class ChatMessageServiceImpl implements ChatMessageService {
return sseEmitter; return sseEmitter;
} }
/**
*
* @param chatCompletionReqDto
* @param userId
* @return
*/
@Override
public SseEmitter chatComments(ChatCompletionReqDto chatCompletionReqDto, String userId,String userName) {
// 目前仅使用讯飞星火大模型
String appId = chatCompletionReqDto.getAppId();
SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) {
sseEmitter = commentsBySession(chatCompletionReqDto, userId,userName);
} else {
sseEmitter = completionByApp(chatCompletionReqDto, false);
}
return sseEmitter;
}
/** /**
* () * ()
*/ */
@ -158,6 +179,73 @@ public class ChatMessageServiceImpl implements ChatMessageService {
return sseEmitter; return sseEmitter;
} }
/**
*
*/
private SseEmitter commentsBySession(ChatCompletionReqDto chatCompletionReqDto, String userId,String userName) {
String userPrompt = chatCompletionReqDto.getText();
String sessionId = chatCompletionReqDto.getSessionId();
ChatSession session = null;
List<ChatSession> sessions = chatSessionRepository.findCommentsById(userId);
if (sessions.size() == 0) {
// 初始化欢迎语
ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
List<ChatSession.Message> history = Collections.singletonList(message);
session = new ChatSession();
session.setId(userId);
session.setCode(PK.CHAT_SESSION);
session.setTitle("评语");
session.setUserId(userId);
session.setCreateTime(Instant.now().toEpochMilli());
session.setUpdateTime(Instant.now().toEpochMilli());
session.setHistory(history);
chatSessionRepository.save(session);
}else {
session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(userId), "该会话不存在");
}
/*
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(sessionId), "该会话不存在");
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "该会话不存在");
}*/
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(sessionId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
List<SparkChatRequestParam.Message> fetchContext(String sessionId, String prompt) { List<SparkChatRequestParam.Message> fetchContext(String sessionId, String prompt) {
List<ChatSession.Message> context = HistoryCache.getContext(sessionId); List<ChatSession.Message> context = HistoryCache.getContext(sessionId);
List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>(); List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>();

Loading…
Cancel
Save