You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

236 lines
10 KiB

package cn.teammodel.service.impl;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.deepseek.DeepSeekClient;
import cn.teammodel.common.PK;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatRequestDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse;
import cn.teammodel.model.dto.ai.deepseek.ChatReqDto;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.model.entity.ai.DeepSeekSession;
import cn.teammodel.model.entity.ai.DeepSeekSession.DeepSeekMessage;
import cn.teammodel.repository.ChatSessionRepository;
import cn.teammodel.repository.DeepSeekRepository;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.DeepSeekService;
import cn.teammodel.service.DeepSeekSessionService;
import cn.teammodel.utils.RepositoryUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
 * Service that calls the DeepSeek chat-completion API, both in blocking mode and in streaming
 * (Server-Sent Events) mode, and records the resulting conversations.
 */
@Service
@Slf4j
public class DeepSeekServiceImpl implements DeepSeekService {
    @Resource
    private DeepSeekSessionService deepSeekService;
    @Resource
    private DeepSeekRepository deepSeekRepository;
    @Resource
    private ChatSessionRepository chatSessionRepository;
    /**
     * Runs the streaming requests of {@link #ChatSeeEmitterAsk} off the request thread.
     * NOTE(review): this pool is unbounded and never shut down; consider a bounded,
     * container-managed executor.
     */
    private final ExecutorService executorService = Executors.newCachedThreadPool();
    /** Serializes request bodies and parses streamed response chunks. */
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Sends a single user message to DeepSeek in blocking mode. On success, the exchange is
     * appended to the caller's DeepSeek session. A new session is created when no session id
     * is given or the given id cannot be found.
     *
     * @param message the user message; its {@code sessionId} is optional
     * @return the DeepSeek response, with {@code wasteTime} set to the elapsed request time in
     *     whole seconds
     */
    @Override
    public DeepSeekChatResponse ChatAsk(ChatReqDto message) {
        long startTime = System.currentTimeMillis();
        DeepSeekChatResponse response = DeepSeekClient.SendRequests(buildRequest(message));
        long endTime = System.currentTimeMillis();
        // Elapsed ("thinking") time in whole seconds; integer division truncates.
        response.setWasteTime((endTime - startTime) / 1000);
        // Persist only when the call succeeded and returned at least one choice;
        // get(0) below would otherwise throw on an empty choice list.
        if (response.getCode() == 200
                && response.getChoices() != null
                && !response.getChoices().isEmpty()) {
            DeepSeekMessage savaMessage = new DeepSeekMessage();
            savaMessage.setId(response.getId());
            savaMessage.setUserText(message.getContent());
            savaMessage.setAiText(response.getChoices().get(0).getMessage().getContent());
            savaMessage.setRole(response.getChoices().get(0).getMessage().getRole());
            savaMessage.setCreateTime(response.getCreated());
            savaMessage.setFinish_reason(response.getChoices().get(0).getFinish_reason());

            DeepSeekSession session = null;
            if (message.getSessionId() != null) {
                // UpdateSession stores new sessions with code PK.DEEPSEEK_SESSION, so they
                // must be looked up with that same code.
                session = deepSeekRepository.findSessionByIdAndCode(
                        message.getSessionId(), PK.DEEPSEEK_SESSION);
            }
            if (session == null) {
                // No session id given, or the id is unknown: start a fresh session.
                session = new DeepSeekSession();
            }
            UpdateSession(message, session, savaMessage, response);
        }
        return response;
    }

    /**
     * Sends a single user message to DeepSeek in streaming mode and returns the emitter
     * produced by {@link DeepSeekClient#SendRequestsEmitter}.
     *
     * @param message the user message
     * @return the emitter that relays the streamed answer
     */
    @Override
    public SseEmitter OKHttpChatSeeEmitterAsk(ChatReqDto message) {
        DeepSeekChatRequestDto requestBody = buildRequest(message);
        requestBody.setStream(true);
        // Streaming is always requested here, so the client's emitter is returned directly.
        // The previous non-streaming branch was unreachable and would have returned an
        // emitter that never completes.
        return DeepSeekClient.SendRequestsEmitter(requestBody);
    }

    /**
     * Streams DeepSeek's answer to {@code chatCompletionReqDto.getText()} back to the caller
     * as SSE events. When the stream ends with {@code [DONE]}, the question/answer pair is
     * appended to the chat session identified by {@code chatCompletionReqDto.getSessionId()}.
     * The work runs on {@link #executorService}; this method returns immediately.
     *
     * @param chatCompletionReqDto the question, model name and session id
     * @return an emitter with no timeout that receives each content chunk, then {@code [DONE]}
     */
    @Override
    public SseEmitter ChatSeeEmitterAsk(ChatCompletionReqDto chatCompletionReqDto) {
        SseEmitter sseEmitter = new SseEmitter(-1L);
        executorService.execute(() -> streamAnswer(chatCompletionReqDto, sseEmitter));
        return sseEmitter;
    }

    //region 辅助方法

    /**
     * Builds a non-streaming DeepSeek request containing only {@code message}
     * (temperature 0, at most 1024 tokens).
     */
    private static DeepSeekChatRequestDto buildRequest(ChatReqDto message) {
        List<ChatReqDto> messages = new ArrayList<>();
        messages.add(message);
        DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto();
        requestBody.setModel(DeepSeekClient.API_Model);
        requestBody.setMessages(messages);
        requestBody.setTemperature(0);
        requestBody.setMax_tokens(1024);
        return requestBody;
    }

    /**
     * Performs the streaming HTTP call and relays each content delta to {@code sseEmitter}.
     * Any failure, including a non-200 HTTP status, completes the emitter with that error.
     */
    private void streamAnswer(ChatCompletionReqDto chatCompletionReqDto, SseEmitter sseEmitter) {
        try {
            log.info("流式回答开始,问题:{}", chatCompletionReqDto.getText());
            Map<String, Object> question = new HashMap<>();
            question.put("role", "user");
            question.put("content", chatCompletionReqDto.getText());
            Map<String, Object> requestMap = new HashMap<>();
            requestMap.put("model", DeepSeekClient.API_Model);
            requestMap.put("messages", Collections.singletonList(question));
            requestMap.put("stream", true);

            HttpPost httpPost = new HttpPost(DeepSeekClient.API_Url);
            httpPost.setHeader("Content-Type", "application/json");
            httpPost.setHeader("Accept", "application/json");
            httpPost.setHeader("Authorization", "Bearer " + DeepSeekClient.API_Key);
            httpPost.setEntity(new StringEntity(
                    objectMapper.writeValueAsString(requestMap), StandardCharsets.UTF_8));

            // Accumulates the full answer so it can be saved to the session history.
            StringBuilder answer = new StringBuilder();
            try (CloseableHttpClient client = HttpClients.createDefault();
                 CloseableHttpResponse response = client.execute(httpPost);
                 BufferedReader reader = new BufferedReader(new InputStreamReader(
                         response.getEntity().getContent(), StandardCharsets.UTF_8))) {
                int status = response.getStatusLine().getStatusCode();
                if (status != 200) {
                    // Error responses carry a JSON body instead of an event stream; without
                    // this check the emitter would complete with no content and no error.
                    StringBuilder errorBody = new StringBuilder();
                    String errorLine;
                    while ((errorLine = reader.readLine()) != null) {
                        errorBody.append(errorLine);
                    }
                    throw new IllegalStateException(
                            "DeepSeek request failed, HTTP " + status + ": " + errorBody);
                }
                String line;
                while ((line = reader.readLine()) != null) {
                    String data = sseData(line);
                    if (data == null || data.isEmpty()) {
                        // Blank separators, SSE comments/other fields, or empty payloads.
                        continue;
                    }
                    if ("[DONE]".equals(data)) {
                        sseEmitter.send("[DONE]");
                        // Conversation finished: record it in the cache and the session.
                        saveHistory(chatCompletionReqDto, answer.toString());
                        break;
                    }
                    String content = objectMapper.readTree(data)
                            .path("choices")
                            .path(0)
                            .path("delta")
                            .path("content")
                            .asText("");
                    if (!content.isEmpty()) {
                        answer.append(content);
                        sseEmitter.send(content);
                    }
                }
                log.info("流式回答结束,{}", question);
                sseEmitter.complete();
            }
        } catch (Exception e) {
            log.error("处理 Deepseek 请求时发生错误", e);
            sseEmitter.completeWithError(e);
        }
    }

    /**
     * Returns the value of an SSE {@code data:} line, or {@code null} for any other line.
     * In the SSE format a single space after the colon is optional and is not part of the
     * value.
     */
    private static String sseData(String line) {
        if (!line.startsWith("data:")) {
            return null;
        }
        String value = line.substring(5);
        return value.startsWith(" ") ? value.substring(1) : value;
    }

    /**
     * Updates the in-memory context cache and appends the finished question/answer pair to
     * the chat session's {@code /history}, bumping its {@code /updateTime}.
     */
    private void saveHistory(ChatCompletionReqDto chatCompletionReqDto, String answer) {
        ChatSession.Message message = ChatSession.Message.of(
                chatCompletionReqDto.getText(), answer, chatCompletionReqDto.getModel());
        HistoryCache.updateContext(chatCompletionReqDto.getSessionId(), message);
        CosmosPatchOperations options = CosmosPatchOperations.create()
                .replace("/updateTime", Instant.now().toEpochMilli())
                .add("/history/-", message);
        chatSessionRepository.save(chatCompletionReqDto.getSessionId(),
                PK.of(PK.CHAT_SESSION), ChatSession.class, options);
    }

    /**
     * Adds {@code savaMessage} to {@code session}.
     * <ul>
     *   <li>If the session has no id, it is initialised as a new DeepSeek session for the
     *       logged-in user and saved via {@link DeepSeekRepository}.</li>
     *   <li>Otherwise the message is appended to the existing history and the session is
     *       updated via {@link DeepSeekSessionService}.</li>
     * </ul>
     *
     * @param message     the original request (its session id is used for updates)
     * @param session     a fresh (id-less) or previously stored session
     * @param savaMessage the exchange to record
     * @param response    the DeepSeek response (its model name is stored on new sessions)
     */
    private void UpdateSession(ChatReqDto message, DeepSeekSession session, DeepSeekMessage savaMessage, DeepSeekChatResponse response) {
        if (session.getId() == null) {
            long now = Instant.now().toEpochMilli();
            // Mutable list: later turns append to this history.
            List<DeepSeekMessage> history = new ArrayList<>();
            history.add(savaMessage);
            String userId = SecurityUtil.getLoginUser().getId();
            session.setId(UUID.randomUUID().toString());
            session.setCode(PK.DEEPSEEK_SESSION);
            session.setTitle("新对话");
            session.setUserId(userId);
            session.setModel(response.getModel());
            session.setCreateTime(now);
            session.setUpdateTime(now);
            session.setHistory(history);
            deepSeekRepository.save(session);
        } else {
            // Copy into a mutable list: the stored history may be null or an immutable list.
            List<DeepSeekMessage> history = session.getHistory() == null
                    ? new ArrayList<>()
                    : new ArrayList<>(session.getHistory());
            history.add(savaMessage);
            session.setHistory(history);
            deepSeekService.updateSession(session, message.getSessionId());
        }
    }
    //endregion
}