You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

356 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package cn.teammodel.service.impl;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.deepseek.DeepSeekClient;
import cn.teammodel.common.PK;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatRequestDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse;
import cn.teammodel.model.dto.ai.deepseek.ChatReqDto;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.model.entity.ai.DeepSeekSession;
import cn.teammodel.model.entity.ai.DeepSeekSession.DeepSeekMessage;
import cn.teammodel.repository.ChatSessionRepository;
import cn.teammodel.repository.DeepSeekRepository;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.DeepSeekService;
import cn.teammodel.service.DeepSeekSessionService;
import cn.teammodel.utils.RepositoryUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okio.BufferedSource;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
 * Service implementation that sends chat requests to DeepSeek.
 */
@Service
@Slf4j
public class DeepSeekServiceImpl implements DeepSeekService {
@Resource
private DeepSeekSessionService deepSeekService;
@Resource
private DeepSeekRepository deepSeekRepository;
@Resource
private ChatSessionRepository chatSessionRepository;
private final ExecutorService executorService = Executors.newCachedThreadPool();
private final ObjectMapper objectMapper = new ObjectMapper();
/**
 * Sends a single, non-streaming question to DeepSeek and, on success, persists the exchange.
 *
 * <p>The request uses {@code DeepSeekClient.API_Model}, temperature 0 and a 1024-token limit.
 * The elapsed wall-clock time of the call, truncated to whole seconds, is stored on the
 * response via {@code setWasteTime}. Persistence only happens when the response code is 200
 * and at least one choice is present.
 *
 * @param message the question; when its session id is non-null the exchange is appended to
 *                that session, otherwise (or when the session cannot be found) a new session
 *                is created
 * @return the response returned by {@code DeepSeekClient.SendRequests}
 */
@Override
public DeepSeekChatResponse ChatAsk(ChatReqDto message) {
    // Build the message list: this endpoint sends only the current question.
    List<ChatReqDto> messages = new ArrayList<>();
    messages.add(message);

    // Build the request body.
    DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto();
    requestBody.setModel(DeepSeekClient.API_Model);
    requestBody.setMessages(messages);
    requestBody.setTemperature(0);
    requestBody.setMax_tokens(1024);

    // Time the upstream call.
    long startTime = System.currentTimeMillis();
    DeepSeekChatResponse response = DeepSeekClient.SendRequests(requestBody);
    long endTime = System.currentTimeMillis();
    // Elapsed time in whole seconds.
    response.setWasteTime((endTime - startTime) / 1000);

    if (response.getCode() == 200) {
        if (response.getChoices() == null || response.getChoices().isEmpty()) {
            // Nothing to persist; avoid an IndexOutOfBoundsException on get(0).
            log.warn("DeepSeek response {} has code 200 but no choices; skipping session update", response.getId());
            return response;
        }

        DeepSeekMessage savaMessage = new DeepSeekMessage();
        savaMessage.setId(response.getId());
        savaMessage.setUserText(message.getContent());
        savaMessage.setAiText(response.getChoices().get(0).getMessage().getContent());
        savaMessage.setRole(response.getChoices().get(0).getMessage().getRole());
        savaMessage.setCreateTime(response.getCreated());
        savaMessage.setFinish_reason(response.getChoices().get(0).getFinish_reason());

        DeepSeekSession session = null;
        if (message.getSessionId() != null) {
            // NOTE(review): UpdateSession stores new sessions with PK.DEEPSEEK_SESSION, while this
            // lookup uses PK.CHAT_SESSION — confirm which partition key is intended.
            session = deepSeekRepository.findSessionByIdAndCode(message.getSessionId(), PK.CHAT_SESSION);
            if (session == null) {
                log.warn("DeepSeek session {} not found; the exchange will be stored in a new session", message.getSessionId());
            }
        }
        if (session == null) {
            // A session without an id makes UpdateSession create a new one.
            session = new DeepSeekSession();
        }
        UpdateSession(message, session, savaMessage, response);
    }
    return response;
}
/**
 * Sends a single question to DeepSeek in streaming mode.
 *
 * <p>The request uses {@code DeepSeekClient.API_Model}, temperature 0, a 1024-token limit and
 * {@code stream = true}; the emitter is the one produced by
 * {@code DeepSeekClient.SendRequestsEmitter}.
 *
 * @param message the question to send
 * @return the emitter returned by {@code DeepSeekClient.SendRequestsEmitter}
 */
@Override
public SseEmitter OKHttpChatSeeEmitterAsk(ChatReqDto message) {
    // Build the message list: only the current question is sent.
    List<ChatReqDto> messages = new ArrayList<>();
    messages.add(message);

    // Build the request body; streaming is always enabled for this endpoint.
    DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto();
    requestBody.setModel(DeepSeekClient.API_Model);
    requestBody.setMessages(messages);
    requestBody.setTemperature(0);
    requestBody.setMax_tokens(1024);
    requestBody.setStream(true);

    return DeepSeekClient.SendRequestsEmitter(requestBody);
}
/**
 * Asks DeepSeek a question and relays the streamed answer through an {@link SseEmitter}.
 *
 * <p>A fixed system message plus the user's text are posted to {@code DeepSeekClient.API_Url}
 * with {@code stream = true} and a 1024-token limit. Each {@code data: } line of the upstream
 * response is parsed; answer fragments are sent as-is and reasoning fragments are sent with a
 * {@code reasoning:} prefix. When {@code [DONE]} arrives the emitter is completed and, if the
 * request asks for it, the exchange is appended to the cached context and to the stored chat
 * session.
 *
 * <p>The upstream response is read on the calling thread, so this method returns only after
 * the upstream stream has ended. The emitter is always completed: normally on {@code [DONE]}
 * or when the upstream ends early, and with an error on a non-2xx status or any exception.
 *
 * @param chatCompletionReqDto the request carrying the question text, model, session id and
 *                             the save flag
 * @return the emitter carrying the relayed events
 */
@Override
public SseEmitter ChatSeeEmitterAsk(ChatCompletionReqDto chatCompletionReqDto) {
    SseEmitter sseEmitter = new SseEmitter(-1L);
    StringBuilder strContent = new StringBuilder();
    StringBuilder strReasoning = new StringBuilder();
    try (CloseableHttpClient client = HttpClients.createDefault()) {
        log.info("流式回答开始,问题:{}", chatCompletionReqDto.getText());

        HttpPost httpPost = new HttpPost(DeepSeekClient.API_Url);
        httpPost.setHeader("Content-Type", "application/json");
        httpPost.setHeader("Accept", "application/json");
        httpPost.setHeader("Authorization", "Bearer " + DeepSeekClient.API_Key);

        List<Map<String, Object>> messages = new ArrayList<>();
        // System message defining the assistant's role.
        Map<String, Object> systemMessage = new HashMap<>();
        systemMessage.put("role", "system");
        systemMessage.put("content", "你是一个教师助手");
        messages.add(systemMessage);
        // The user's question.
        Map<String, Object> userMessage = new HashMap<>();
        userMessage.put("role", "user");
        userMessage.put("content", chatCompletionReqDto.getText());
        messages.add(userMessage);

        // Request payload sent to DeepSeek.
        Map<String, Object> requestMap = new HashMap<>();
        requestMap.put("model", chatCompletionReqDto.getModel());
        requestMap.put("messages", messages);
        requestMap.put("stream", true);
        requestMap.put("max_tokens", 1024);
        String requestBody = objectMapper.writeValueAsString(requestMap);
        httpPost.setEntity(new StringEntity(requestBody, StandardCharsets.UTF_8));

        try (CloseableHttpResponse response = client.execute(httpPost)) {
            int status = response.getStatusLine().getStatusCode();
            if (status < 200 || status >= 300) {
                // An error response carries no "data: " lines, so without this check the
                // emitter (which has no timeout) would never be completed.
                throw new IOException("DeepSeek request failed, HTTP status " + status);
            }

            boolean finished = false;
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    if (!line.startsWith("data: ")) {
                        continue;
                    }
                    String jsonData = line.substring(6);
                    if ("[DONE]".equals(jsonData)) {
                        sseEmitter.send("[DONE]");
                        // Close the connection to the client.
                        sseEmitter.complete();
                        finished = true;
                        // Conversation finished: update the session history.
                        if (chatCompletionReqDto.isSave()) {
                            ChatSession.Message message = ChatSession.Message.of(
                                    chatCompletionReqDto.getText(),
                                    strContent.toString(),
                                    chatCompletionReqDto.getModel(),
                                    strReasoning.toString());
                            HistoryCache.updateContext(chatCompletionReqDto.getSessionId(), message);
                            CosmosPatchOperations options = CosmosPatchOperations.create()
                                    .replace("/updateTime", Instant.now().toEpochMilli())
                                    .add("/history/-", message);
                            chatSessionRepository.save(chatCompletionReqDto.getSessionId(), PK.of(PK.CHAT_SESSION), ChatSession.class, options);
                        }
                        break;
                    }

                    JsonNode delta = objectMapper.readTree(jsonData)
                            .path("choices")
                            .path(0)
                            .path("delta");
                    String content = delta.path("content").asText("");
                    // Reasoning (chain-of-thought) fragment, if the model emits one.
                    String reasoningContent = delta.path("reasoning_content").asText("");
                    if (!content.isEmpty()) {
                        strContent.append(content);
                        sseEmitter.send(content);
                    }
                    if (!reasoningContent.isEmpty()) {
                        strReasoning.append(reasoningContent);
                        sseEmitter.send("reasoning:" + reasoningContent);
                    }
                }
            }

            if (!finished) {
                // Upstream closed without sending [DONE]; still release the client connection.
                log.warn("DeepSeek stream ended without [DONE]; completing emitter");
                sseEmitter.complete();
            }
            log.info("流式回答结束,{}", messages);
        }
    } catch (Exception e) {
        log.error("处理 Deepseek 请求时发生错误", e);
        sseEmitter.completeWithError(e);
    }
    return sseEmitter;
}
/**
 * HTTP client used by {@link #ReasonerChatCompletion}; created once and reused because each
 * {@code OkHttpClient} instance owns its own connection pool and threads.
 */
private final OkHttpClient reasonerHttpClient = new OkHttpClient.Builder()
        .readTimeout(30, TimeUnit.SECONDS)
        .build();

/**
 * Asks DeepSeek a question with streaming enabled and relays the answer through an
 * {@link SseEmitter}.
 *
 * <p>The user's text is posted as a single {@code user} message to
 * {@code DeepSeekClient.API_Url}. Each {@code data: } line of the upstream response is parsed;
 * answer fragments are sent as-is and reasoning fragments are sent with a {@code reasoning:}
 * prefix. {@code [DONE]} is forwarded to the client before the emitter is completed.
 *
 * <p>The upstream response is read on the calling thread, so this method returns only after
 * the upstream stream has ended. The emitter is always completed: normally when the stream
 * ends, and with an error when the request cannot be serialized, the upstream responds
 * unsuccessfully or without a body, or an I/O error occurs.
 *
 * @param chatCompletionReqDto the request carrying the question text and the model name
 * @return the emitter carrying the relayed events
 */
@Override
public SseEmitter ReasonerChatCompletion(ChatCompletionReqDto chatCompletionReqDto) {
    SseEmitter sseEmitter = new SseEmitter(-1L);
    MediaType JSON = MediaType.parse("application/json; charset=utf-8");

    // Build the request payload: one user message, streaming enabled.
    Map<String, Object> question = new HashMap<>();
    question.put("role", "user");
    question.put("content", chatCompletionReqDto.getText());
    Map<String, Object> requestMap = new HashMap<>();
    requestMap.put("model", chatCompletionReqDto.getModel());
    requestMap.put("messages", Collections.singletonList(question));
    requestMap.put("stream", true);

    String requestBody;
    try {
        requestBody = objectMapper.writeValueAsString(requestMap);
    } catch (JsonProcessingException e) {
        log.error("处理用户转换问题出错", e);
        sseEmitter.completeWithError(e);
        // Without a payload there is nothing to send; continuing would fail on a null body.
        return sseEmitter;
    }

    Request request = new Request.Builder()
            .url(DeepSeekClient.API_Url)
            .post(RequestBody.create(JSON, requestBody))
            .addHeader("Authorization", "Bearer " + DeepSeekClient.API_Key)
            .addHeader("Accept", "text/event-stream")
            .build();

    // Closing the response also closes its body.
    try (Response response = reasonerHttpClient.newCall(request).execute()) {
        ResponseBody body = response.body();
        if (!response.isSuccessful() || body == null) {
            // The emitter has no timeout, so it must be completed explicitly on failure.
            IOException failure = new IOException("DeepSeek request failed, HTTP status " + response.code());
            log.error("处理 Deepseek 请求时发生错误", failure);
            sseEmitter.completeWithError(failure);
            return sseEmitter;
        }

        BufferedSource source = body.source();
        while (!source.exhausted()) {
            String line = source.readUtf8Line();
            if (line == null || !line.startsWith("data: ")) {
                continue;
            }
            String json = line.substring(6).trim();
            if (json.equals("[DONE]")) {
                sseEmitter.send("[DONE]");
                break;
            }

            JsonNode delta = objectMapper.readTree(json)
                    .path("choices")
                    .path(0)
                    .path("delta");
            String content = delta.path("content").asText("");
            String reasoningContent = delta.path("reasoning_content").asText("");
            if (!content.isEmpty()) {
                sseEmitter.send(content);
            }
            if (!reasoningContent.isEmpty()) {
                sseEmitter.send("reasoning:" + reasoningContent);
            }
        }
        log.info("流式回答结束,{}", requestBody);
        sseEmitter.complete();
    } catch (IOException e) {
        log.error("处理 Deepseek 请求时发生错误", e);
        sseEmitter.completeWithError(e);
    }
    return sseEmitter;
}
//region 辅助方法
/**
 * Creates a new session containing the given exchange, or appends the exchange to an
 * existing session.
 *
 * <p>A session whose id is {@code null} (or a {@code null} session) is treated as new: it is
 * given a random UUID, the {@code PK.DEEPSEEK_SESSION} code, a default title, the current
 * user's id, the response's model and identical create/update timestamps, and is then saved
 * through the repository. Otherwise the exchange is added to the session's history and the
 * session is handed to {@code deepSeekService.updateSession}.
 *
 * @param message     the original request; its session id is passed on when updating
 * @param session     the session to create or update; may be {@code null}
 * @param savaMessage the exchange to record in the session history
 * @param response    the DeepSeek response; supplies the model name for new sessions
 */
private void UpdateSession(ChatReqDto message, DeepSeekSession session, DeepSeekMessage savaMessage, DeepSeekChatResponse response) {
    if (session == null) {
        // A failed lookup is handled like a brand-new conversation.
        session = new DeepSeekSession();
    }

    if (session.getId() == null) {
        // Use a mutable list so the history of this instance can still be appended to.
        List<DeepSeekMessage> history = new ArrayList<>();
        history.add(savaMessage);
        String userId = SecurityUtil.getLoginUser().getId();
        // Single timestamp so createTime and updateTime of a new session agree.
        long now = Instant.now().toEpochMilli();

        session.setId(UUID.randomUUID().toString());
        session.setCode(PK.DEEPSEEK_SESSION);
        session.setTitle("新对话");
        session.setUserId(userId);
        session.setModel(response.getModel());
        session.setCreateTime(now);
        session.setUpdateTime(now);
        session.setHistory(history);
        deepSeekRepository.save(session);
    } else {
        if (session.getHistory() == null) {
            // Guard against a stored session that has no history yet.
            session.setHistory(new ArrayList<DeepSeekMessage>());
        }
        session.getHistory().add(savaMessage);
        deepSeekService.updateSession(session, message.getSessionId());
    }
}
//endregion
}