diff --git a/src/main/java/cn/teammodel/ai/SparkGptClient.java b/src/main/java/cn/teammodel/ai/SparkGptClient.java index 0c3d2f2..61f4361 100644 --- a/src/main/java/cn/teammodel/ai/SparkGptClient.java +++ b/src/main/java/cn/teammodel/ai/SparkGptClient.java @@ -6,10 +6,9 @@ import cn.teammodel.ai.domain.SparkChatRequestParam; import cn.teammodel.ai.listener.SparkGptStreamListener; import lombok.Data; import lombok.extern.slf4j.Slf4j; -import okhttp3.HttpUrl; -import okhttp3.OkHttpClient; -import okhttp3.Request; +import okhttp3.*; import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import javax.annotation.Resource; @@ -69,7 +68,6 @@ public class SparkGptClient implements InitializingBean { e.printStackTrace(); } } - /** * 生成鉴权URL */ @@ -111,6 +109,7 @@ public class SparkGptClient implements InitializingBean { return httpUrl.toString(); } + @Override public void afterPropertiesSet() throws Exception { init(); diff --git a/src/main/java/cn/teammodel/ai/deepseek/DeepSeekClient.java b/src/main/java/cn/teammodel/ai/deepseek/DeepSeekClient.java index 8fb16ad..f976587 100644 --- a/src/main/java/cn/teammodel/ai/deepseek/DeepSeekClient.java +++ b/src/main/java/cn/teammodel/ai/deepseek/DeepSeekClient.java @@ -1,16 +1,29 @@ package cn.teammodel.ai.deepseek; +import cn.teammodel.ai.SparkGptClient; +import cn.teammodel.ai.SseHelper; +import cn.teammodel.ai.cache.HistoryCache; +import cn.teammodel.ai.domain.SparkChatRequestParam; import cn.teammodel.common.ErrorCode; +import cn.teammodel.common.PK; import cn.teammodel.config.exception.ServiceException; -import cn.teammodel.model.dto.ai.deepseek.ChatRequestOKHttpDto; -import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto; +import cn.teammodel.model.dto.ai.ChatCompletionReqDto; +import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatRequestDto; +import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse; import cn.teammodel.model.dto.ai.deepseek.ChatReqDto; +import cn.teammodel.model.entity.ai.ChatSession; +import cn.teammodel.repository.ChatSessionRepository; import cn.teammodel.utils.JsonUtil; +import com.azure.cosmos.models.CosmosPatchOperations; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.type.TypeFactory; import com.google.gson.Gson; +import com.sun.org.apache.bcel.internal.generic.NEW; import lombok.extern.slf4j.Slf4j; import okhttp3.*; +import okio.Buffer; +import okio.BufferedSource; import org.apache.http.HttpEntity; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPost; @@ -19,20 +32,35 @@ import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; import org.apache.http.util.EntityUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import javax.annotation.Resource; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.HashMap; import java.util.Map; import java.io.IOException; import java.io.InputStream; import java.util.*; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; @Slf4j -public class DeepSeekClient { - private static final String API_Key; - private static final String API_Url; +public class DeepSeekClient { + public static final String API_Key; + public static final String API_Url; public static String API_Model; + + @Resource + private static ChatSessionRepository chatSessionRepository; + + private static final ExecutorService executorService = Executors.newCachedThreadPool(); + private static final ObjectMapper objectMapper = new ObjectMapper(); /** * 读取配置文件 读取key 和url */ @@ -61,7 +89,7 @@ public class DeepSeekClient { msg.add(mssage); //构建请求头 - ChatRequestOKHttpDto requestBody = new ChatRequestOKHttpDto(); + DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto(); requestBody.setModel(API_Model); requestBody.setMessages(msg); requestBody.setTemperature(0); @@ -69,7 +97,7 @@ public class DeepSeekClient { long startTime = System.currentTimeMillis(); //发起请求 - ChatResponseDto response = SendRequests(requestBody); + DeepSeekChatResponse response = SendRequests(requestBody); //Map response = SendRequest(requestBody); Long endTime = System.currentTimeMillis(); //思考耗时 秒 @@ -85,9 +113,9 @@ public class DeepSeekClient { * @param requestBody * @return */ - public static ChatResponseDto SendRequests(ChatRequestOKHttpDto requestBody) + public static DeepSeekChatResponse SendRequests(DeepSeekChatRequestDto requestBody) { - ChatResponseDto chatResponse = new ChatResponseDto(); + DeepSeekChatResponse chatResponse = new DeepSeekChatResponse(); //OkHttpClient client = new OkHttpClient().newBuilder().connectTimeout(60, TimeUnit.SECONDS).build();//设置连接超时时间 1分钟 OkHttpClient client = new OkHttpClient().newBuilder().build();//设置连接超时时间 1分钟 @@ -104,13 +132,36 @@ public class DeepSeekClient { .addHeader("Accept", "application/json") .addHeader("Authorization", "Bearer "+API_Key) .build(); - + ObjectMapper objectMapper = new ObjectMapper(); try(Response response = client.newCall(request).execute()) { if (response.isSuccessful() && response.body() != null) { - String responseBody = response.body().string(); + StringBuilder responseBody = new StringBuilder(); + try (BufferedSource source = response.body().source()) { + Buffer buffer = new Buffer(); + while (source.read(buffer, 2048) != -1) { + // 处理流式数据 + String chunk = buffer.readUtf8(); + if (chunk.startsWith("data:") && !chunk.contains("data: [DONE]")) { + String[] split = chunk.split("data:"); + for (String result : split) { + if (StringUtils.hasLength(result) && StringUtils.hasLength(result.trim())) { + JsonNode jsonNode = objectMapper.readTree(result); + if (!ObjectUtils.isEmpty(jsonNode.get("choices"))) { + JsonNode delta = jsonNode.get("choices").get(0).get("delta"); + log.debug("Delta Content: {}", delta.get("content").asText()); + responseBody.append(delta.get("content").asText()); + } + } + } + } + } + } + + + String responseBody1 = response.body().string(); // 使用 Gson 将 JSON 字符串转换为 MyEntity 对象 Gson gson = new Gson(); - chatResponse = gson.fromJson(responseBody, ChatResponseDto.class); + chatResponse = gson.fromJson(responseBody1, DeepSeekChatResponse.class); // 确保关闭响应体以释放资源 response.body().close(); chatResponse.setCode(200); @@ -132,12 +183,145 @@ public class DeepSeekClient { return chatResponse; } + + /** + * OkHttpClient 方式请求 流式返回 + * @param requestBody + * @return + */ + public static SseEmitter SendRequestsEmitter(DeepSeekChatRequestDto requestBody) + { + SseEmitter sseEmitter = new SseEmitter(-1L); + //OkHttpClient client = new OkHttpClient().newBuilder().connectTimeout(60, TimeUnit.SECONDS).build();//设置连接超时时间 1分钟 + OkHttpClient client = new OkHttpClient().newBuilder().build();// + + MediaType mediaType = MediaType.parse("application/json"); + //String content = "{\n \"messages\": [\n {\n \"content\": \"You are a helpful assistant\",\n \"role\": \"system\"\n },\n {\n \"content\": \"Hi\",\n \"role\": \"user\"\n }\n ],\n \"model\": \"deepseek-chat\",\n \"frequency_penalty\": 0,\n \"max_tokens\": 2048,\n \"presence_penalty\": 0,\n \"response_format\": {\n \"type\": \"text\"\n },\n \"stop\": null,\n \"stream\": false,\n \"stream_options\": null,\n \"temperature\": 1,\n \"top_p\": 1,\n \"tools\": null,\n \"tool_choice\": \"none\",\n \"logprobs\": false,\n \"top_logprobs\": null\n}"; + String content = JsonUtil.convertToJson(requestBody); + + RequestBody body = RequestBody.create(mediaType, content); + Request request = new Request.Builder() + .url(API_Url) + .method("POST", body) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", "application/json") + .addHeader("Authorization", "Bearer "+API_Key) + .build(); + ObjectMapper objectMapper = new ObjectMapper(); + try(Response response = client.newCall(request).execute()) { + if (response.isSuccessful() && response.body() != null) { + StringBuilder responseBody = new StringBuilder(); + try (BufferedSource source = response.body().source()) { + Buffer buffer = new Buffer(); + while (source.read(buffer, 2048) != -1) { + // 处理流式数据 + String chunk = buffer.readUtf8(); + if (chunk.startsWith("data:") && !chunk.contains("data: [DONE]")) { + String[] split = chunk.split("data:"); + for (String result : split) { + if (StringUtils.hasLength(result) && StringUtils.hasLength(result.trim())) { + JsonNode jsonNode = objectMapper.readTree(result); + if (!ObjectUtils.isEmpty(jsonNode.get("choices"))) { + JsonNode delta = jsonNode.get("choices").get(0).get("delta"); + log.debug("Delta Content: {}", delta.get("content").asText()); + sseEmitter.send(delta); + } + } + } + } + } + }catch (IOException e) { + sseEmitter.completeWithError(e); + } + + } else { + sseEmitter.completeWithError(new Exception("请求DeepSeek服务器失败")); + } + } catch (IOException e) { + sseEmitter.completeWithError(e); + } + return sseEmitter; + } + + /** + * HttpClient 方式请求 + * @param chatCompletionReqDto + * @return + */ + public static SseEmitter HttpClientSendRequests(ChatCompletionReqDto chatCompletionReqDto){ + SseEmitter emitter = new SseEmitter(-1L); + List msg = new ArrayList<>(); + msg.add(new ChatReqDto(chatCompletionReqDto.getSessionId(), "user", chatCompletionReqDto.getText())); + //构建请求头 + DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto(); + requestBody.setModel(DeepSeekClient.API_Model); + requestBody.setMessages(msg); + requestBody.setTemperature(0); + requestBody.setStream(true); + + try (CloseableHttpClient httpClient = HttpClients.createDefault()) { + // 创建HttpPost对象 + HttpPost httpPost = new HttpPost(API_Url); + //添加请求头 + httpPost.setHeader("Content-Type", "application/json"); + httpPost.setHeader("Accept", "application/json"); + httpPost.setHeader("Authorization", "Bearer " + API_Key); + + requestBody.setStream(true); + // 设置请求体 + String jsonContent = JsonUtil.convertToJson(requestBody); + httpPost.setEntity(new StringEntity(jsonContent, ContentType.create("application/json", "UTF-8"))); + StringBuilder responseBody = new StringBuilder(); + try (CloseableHttpResponse response = httpClient.execute(httpPost); + BufferedReader reader = new BufferedReader(new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8))) { + String line; + StringBuilder strContent = new StringBuilder(); + while ((line = reader.readLine()) != null) { + if (line.startsWith("data: ")) { + String jsonData = line.substring(6); + if ("[DONE]".equals(jsonData)) { + //SseHelper.send(emitter, "[DONE]"); + emitter.send("[DONE]"); + // 更新历史会话记录 + ChatSession.Message message = ChatSession.Message.of(chatCompletionReqDto.getText(), strContent.toString(),chatCompletionReqDto.getModel()); + HistoryCache.updateContext(chatCompletionReqDto.getSessionId(), message); + CosmosPatchOperations options = CosmosPatchOperations.create() + .replace("/updateTime", Instant.now().toEpochMilli()) + .add("/history/-", message); + chatSessionRepository.save(chatCompletionReqDto.getSessionId(), PK.of(PK.CHAT_SESSION), ChatSession.class, options); + break; + } + JsonNode node = objectMapper.readTree(jsonData); + String content = node.path("choices") + .path(0) + .path("delta") + .path("content") + .asText(""); + if (!content.isEmpty()) { + responseBody.append(content); + strContent.append(content); + emitter.send(content); + } + } + } + emitter.complete(); + }catch (Exception e) + { + emitter.completeWithError(e); + } + }catch (Exception e) { + emitter.completeWithError(e); + } + + return emitter; + } + /*** * HttpClient 方式请求 * @param requestBody * @return */ - public static Map SendRequest(ChatRequestOKHttpDto requestBody) { + public static Map SendRequest(DeepSeekChatRequestDto requestBody) { Map mapper = new HashMap<>(); try (CloseableHttpClient httpClient = HttpClients.createDefault()) { // 创建HttpPost对象 @@ -175,8 +359,4 @@ public class DeepSeekClient { //TODO 请求接口 return mapper; } - - - - } diff --git a/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java b/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java index 7e4ae5d..0a6797b 100644 --- a/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java +++ b/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java @@ -46,16 +46,17 @@ public class SparkChatRequestParam { public static class Message { private String role; private String content; + private String model; /** * 一个类下面有两种类型的对象,使用静态的对象产生方法,好思路 */ - public static Message ofUser(String content){ - return new Message("user",content); + public static Message ofUser(String content,String model){ + return new Message("user",content,model); } - public static Message ofAssistant(String content){ - return new Message("assistant",content); + public static Message ofAssistant(String content, String model){ + return new Message("assistant",content,model); } } diff --git a/src/main/java/cn/teammodel/controller/frontend/AiDeepSeekController.java b/src/main/java/cn/teammodel/controller/frontend/AiDeepSeekController.java index 697fd18..cffc411 100644 --- a/src/main/java/cn/teammodel/controller/frontend/AiDeepSeekController.java +++ b/src/main/java/cn/teammodel/controller/frontend/AiDeepSeekController.java @@ -1,7 +1,8 @@ package cn.teammodel.controller.frontend; import cn.teammodel.common.IdRequest; -import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto; +import cn.teammodel.model.dto.ai.ChatCompletionReqDto; +import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse; import cn.teammodel.model.dto.ai.deepseek.ChatReqDto; import cn.teammodel.model.entity.TmdUserDetail; import cn.teammodel.model.entity.ai.DeepSeekSession; @@ -18,12 +19,15 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import javax.validation.Valid; import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; @RestController @RequestMapping("/aiDeepSeek") @Api(tags = "AI DeepSeek 能力") public class AiDeepSeekController { + private final ExecutorService executorService = Executors.newCachedThreadPool(); /** * 访问DeepSeek方法 */ @@ -118,19 +122,29 @@ public class AiDeepSeekController { * @return */ @PostMapping("chat") - @ApiOperation("与deepseek的对话") - public R ChatCompletion(@RequestBody @Valid ChatReqDto messageDto) { - ChatResponseDto chatResponse = deepSeekChatService.ChatAsk(messageDto); + @ApiOperation("单独 与deepseek的对话") + public R ChatCompletion(@RequestBody @Valid ChatReqDto messageDto) { + DeepSeekChatResponse chatResponse = deepSeekChatService.ChatAsk(messageDto); return R.success(chatResponse); } + @PostMapping("okhttp/emitter") + @ApiOperation("单 与deepseek的对话") + public SseEmitter ChatEmiter(@RequestBody @Valid ChatReqDto messageDto) { + return deepSeekChatService.OKHttpChatSeeEmitterAsk(messageDto); + } + + /** + * 与deepseek的对话 并保存到数据库中 + * @param chatCompletionReqDto + * @return + */ @PostMapping("chat/completion") @ApiOperation("与 spark 的流式对话") - public SseEmitter chatCompletion(@RequestBody @Valid ChatReqDto messageDto) { + public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) { String userId = SecurityUtil.getLoginUser().getId(); - SseEmitter sseEmitter = new SseEmitter(); - return sseEmitter; - //return deepSeekChatService.ChatSeeEmitterAsk(messageDto, userId); + SseEmitter sseEmitter = new SseEmitter(-1L); + return deepSeekChatService.ChatSeeEmitterAsk(chatCompletionReqDto); } diff --git a/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java b/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java index 86a9466..076c2b6 100644 --- a/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java +++ b/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java @@ -13,6 +13,10 @@ public class ChatCompletionReqDto { */ @ApiModelProperty("会话id,没有则为空") private String appId; + + @ApiModelProperty("模型") + private String model = "SparkMax"; + @NotBlank(message = "请输入消息内容") private String text; } \ No newline at end of file diff --git a/src/main/java/cn/teammodel/model/dto/ai/deepseek/ChatResponseDto.java b/src/main/java/cn/teammodel/model/dto/ai/deepseek/ChatResponseDto.java deleted file mode 100644 index 82ffd8f..0000000 --- a/src/main/java/cn/teammodel/model/dto/ai/deepseek/ChatResponseDto.java +++ /dev/null @@ -1,66 +0,0 @@ -package cn.teammodel.model.dto.ai.deepseek; - -import io.swagger.annotations.ApiModelProperty; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Builder.Default; -import lombok.Data; - -import java.util.List; - -/** - * 接收DeepSeek响应的json 格式来设计响应体 用来响应体字符串 - */ -@Data -public class ChatResponseDto { - private int code; - private String msg; - private long wasteTime; - private String id; - private String object; - private long created; - private String model; - - private Usage usage; - - /** - * 返回内容 - */ - private List choices; - private String system_fingerprint; - @Data - @AllArgsConstructor - public static class Choice { - private int index; - private DeepSeekMessage message; - private String logprobs; - private String finish_reason; - } - @Data - public static class Usage{ - private int prompt_tokens; - private int completion_tokens; - private int total_tokens; - private Prompt_Tokens_Details prompt_tokens_details; - private int prompt_cache_hit_tokens; - private int prompt_cache_miss_tokens; - } - @Data - public static class Prompt_Tokens_Details { - private int cached_tokens; - } - - @Data - public static class DeepSeekMessage{ - - /** - * 角色 - */ - private String role; - - /** - * 提问内容 - */ - private String content; - } -} diff --git a/src/main/java/cn/teammodel/model/dto/ai/deepseek/ChatRequestOKHttpDto.java b/src/main/java/cn/teammodel/model/dto/ai/deepseek/DeepSeekChatRequestDto.java similarity index 89% rename from src/main/java/cn/teammodel/model/dto/ai/deepseek/ChatRequestOKHttpDto.java rename to src/main/java/cn/teammodel/model/dto/ai/deepseek/DeepSeekChatRequestDto.java index 20b845f..1309682 100644 --- a/src/main/java/cn/teammodel/model/dto/ai/deepseek/ChatRequestOKHttpDto.java +++ b/src/main/java/cn/teammodel/model/dto/ai/deepseek/DeepSeekChatRequestDto.java @@ -37,7 +37,7 @@ import java.util.List; * } */ @Data -public class ChatRequestOKHttpDto { +public class DeepSeekChatRequestDto { @ApiModelProperty("会话模型") private String model; @ApiModelProperty("会话内容") @@ -53,4 +53,11 @@ public class ChatRequestOKHttpDto { */ @ApiModelProperty("最大生成token数") private int max_tokens = 2048; + + @ApiModelProperty("是否流式输出") + private boolean stream = false; + + public boolean getStream() { + return stream; + } } diff --git a/src/main/java/cn/teammodel/model/dto/ai/deepseek/DeepSeekChatResponse.java b/src/main/java/cn/teammodel/model/dto/ai/deepseek/DeepSeekChatResponse.java new file mode 100644 index 0000000..628c34b --- /dev/null +++ b/src/main/java/cn/teammodel/model/dto/ai/deepseek/DeepSeekChatResponse.java @@ -0,0 +1,114 @@ +package cn.teammodel.model.dto.ai.deepseek; + +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.util.List; + +/** + * 接收DeepSeek响应的json 格式来设计响应体 用来响应体字符串 + */ +@Data +public class DeepSeekChatResponse { + /** + * 响应码 + */ + private int code; + /** + * 响应内容 + */ + private String msg; + /** + * 思考项目 + */ + private long wasteTime; + /** + * 会话id + */ + private String id; + /** + * 对象的类型, 其值为 chat.completion + */ + private String object; + + /** + * 创建聊天完成时的 Unix 时间戳(以秒为单位)。 + */ + private long created; + /** + * 生成该 completion 的模型名。 + */ + private String model; + + /** + * 返回内容 + */ + private List choices; + + /** + * 该对话补全请求的用量信息。 + */ + private Usage usage; + + private String system_fingerprint; + + /** + * 模型生成的 completion 的选择列表。 + */ + @Data + @AllArgsConstructor + public static class Choice { + /** + * 该 completion 在模型生成的 completion 的选择列表中的索引。 + */ + private int index; + //内容 + private DeepSeekMessage message; + //private String logprobs; + private String finish_reason; + } + + /** + * 该对话补全请求的用量信息。 + */ + @Data + public static class Usage{ + //用户 prompt 所包含的 token 数。该值等于 prompt_cache_hit_tokens + prompt_cache_miss_tokens + private int prompt_tokens; + //模型 completion 产生的 token 数。 + private int completion_tokens; + //该请求中,所有 token 的数量(prompt + completion)。 + private int total_tokens; + private Prompt_Tokens_Details prompt_tokens_details; + //用户 prompt 中,命中上下文缓存的 token 数。 + private int prompt_cache_hit_tokens; + //用户 prompt 中,未命中上下文缓存的 token 数。 + private int prompt_cache_miss_tokens; + } + + /** + * completion tokens 的详细信息。 + */ + @Data + public static class Prompt_Tokens_Details { + //推理模型所产生的思维链 token 数量 + private int cached_tokens; + } + + /** + * 聊天内容 + */ + @Data + public static class DeepSeekMessage{ + + /** + * 角色 + */ + private String role; + + /** + * 该 completion 的内容。 + */ + private String content; + } +} diff --git a/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java b/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java index fcda6be..f78e55f 100644 --- a/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java +++ b/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java @@ -45,13 +45,16 @@ public class ChatSession extends BaseItem { private Integer cost; private Long createTime; - public static Message of(String userText, String gptText) { + public String model; + + public static Message of(String userText, String gptText,String model) { Message message = new Message(); message.setId(UUID.randomUUID().toString()); message.setCost(0); message.setUserText(userText); message.setGptText(gptText); message.setCreateTime(Instant.now().toEpochMilli()); + message.setModel(model); return message; } } diff --git a/src/main/java/cn/teammodel/service/DeepSeekService.java b/src/main/java/cn/teammodel/service/DeepSeekService.java index 5d53c70..13b95bb 100644 --- a/src/main/java/cn/teammodel/service/DeepSeekService.java +++ b/src/main/java/cn/teammodel/service/DeepSeekService.java @@ -1,18 +1,34 @@ package cn.teammodel.service; -import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto; +import cn.teammodel.model.dto.ai.ChatCompletionReqDto; +import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse; import cn.teammodel.model.dto.ai.deepseek.ChatReqDto; -import reactor.core.publisher.Flux; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; /** * 访问DeepSeek方法 */ public interface DeepSeekService { + /** + * 单独 获取AI的回答 + * @param message + * @return + */ + DeepSeekChatResponse ChatAsk(ChatReqDto message); + /** * 获取AI的回答 * @param message * @return */ - ChatResponseDto ChatAsk(ChatReqDto message); + SseEmitter OKHttpChatSeeEmitterAsk(ChatReqDto message); + + /** + * 获取AI的回答 + * @param chatCompletionReqDto + * @return + */ + SseEmitter ChatSeeEmitterAsk(ChatCompletionReqDto chatCompletionReqDto); + } diff --git a/src/main/java/cn/teammodel/service/DeepSeekSessionService.java b/src/main/java/cn/teammodel/service/DeepSeekSessionService.java index 90ef875..27e01b0 100644 --- a/src/main/java/cn/teammodel/service/DeepSeekSessionService.java +++ b/src/main/java/cn/teammodel/service/DeepSeekSessionService.java @@ -1,6 +1,5 @@ package cn.teammodel.service; -import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto; import cn.teammodel.model.entity.ai.DeepSeekSession; import cn.teammodel.model.entity.ai.DeepSeekSession.DeepSeekMessage; diff --git a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java index ec3a9f0..416e295 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java @@ -4,6 +4,7 @@ import cn.teammodel.ai.JsonLoader; import cn.teammodel.ai.SparkGptClient; import cn.teammodel.ai.SseHelper; import cn.teammodel.ai.cache.HistoryCache; +import cn.teammodel.ai.deepseek.DeepSeekClient; import cn.teammodel.ai.domain.SparkChatRequestParam; import cn.teammodel.ai.listener.SparkGptStreamListener; import cn.teammodel.common.ErrorCode; @@ -18,6 +19,7 @@ import cn.teammodel.model.entity.User; import cn.teammodel.model.entity.ai.ChatSession; import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.service.ChatMessageService; +import cn.teammodel.service.DeepSeekService; import cn.teammodel.utils.RepositoryUtil; import com.alibaba.fastjson2.JSON; import com.alibaba.fastjson2.TypeReference; @@ -49,12 +51,20 @@ public class ChatMessageServiceImpl implements ChatMessageService { @Resource private JsonLoader jsonLoader; + /** + * 访问DeepSeek方法 + */ + @Resource + private DeepSeekService deepSeekChatService; + + @Override public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) { // 目前仅使用讯飞星火大模型 String appId = chatCompletionReqDto.getAppId(); SseEmitter sseEmitter; - if (StringUtils.isEmpty(appId)) { + // + if (StringUtils.isEmpty(appId) || chatCompletionReqDto.getModel().equals("DeepSeek_Chat")) { sseEmitter = completionBySession(chatCompletionReqDto, userId); } else { sseEmitter = completionByApp(chatCompletionReqDto, false); @@ -165,37 +175,57 @@ public class ChatMessageServiceImpl implements ChatMessageService { String appPrompt = chatApp.getPrompt(); SseEmitter sseEmitter = new SseEmitter(-1L); - SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); - // open 回调 - listener.setOnOpen((s) -> { - // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) - log.info("callback: ws open event emmit"); - }); - // 对话完成的回调 - listener.setOnComplete((s) -> { - log.info("callback: ws complete event emmit"); - SseHelper.send(sseEmitter, "[DONE]"); - // 处理完成后的事件: - if (!justApi) { - // 保存消息记录, 缓存更改 + + switch (chatCompletionReqDto.getModel()) { + //星火大模型 + case "SparkMax": + { + SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); + // open 回调 + listener.setOnOpen((s) -> { + // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) + log.info("callback: ws open event emmit"); + }); + // 对话完成的回调 + SseEmitter finalSseEmitter = sseEmitter; + listener.setOnComplete((s) -> { + log.info("callback: ws complete event emmit"); + SseHelper.send(finalSseEmitter, "[DONE]"); + // 处理完成后的事件: + if (!justApi) { + // 保存消息记录, 缓存更改 + } + }); + // 错误的回调 + listener.setOnError((s) -> { + log.error("callback: ws error, info: " + s); + // 返还积分 + }); + List messageList = new ArrayList<>(); + messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt,"SparkMax")); + messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt,"SparkMax")); + SparkChatRequestParam requestParam = SparkChatRequestParam + .builder() + .uid(userId) + .chatId(appId) + .messageList(messageList) + .build(); + sparkGptClient.streamChatCompletion(requestParam, listener); + return finalSseEmitter; } - }); - // 错误的回调 - listener.setOnError((s) -> { - log.error("callback: ws error, info: " + s); - // 返还积分 - }); - List messageList = new ArrayList<>(); - messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt)); - messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt)); - SparkChatRequestParam requestParam = SparkChatRequestParam - .builder() - .uid(userId) - .chatId(appId) - .messageList(messageList) - .build(); - sparkGptClient.streamChatCompletion(requestParam, listener); - return sseEmitter; + // DeepSeek 模型 + case "DeepSeek_Chat": + { + // OKHttp 方式请求 + sseEmitter = deepSeekChatService.ChatSeeEmitterAsk(chatCompletionReqDto); + //HttpClient 方式请求 + //sseEmitter = DeepSeekClient.HttpClientSendRequests(chatCompletionReqDto); + + return sseEmitter; + } + default: + throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "不支持的模型"); + } } /** @@ -211,38 +241,53 @@ public class ChatMessageServiceImpl implements ChatMessageService { } SseEmitter sseEmitter = new SseEmitter(-1L); - SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); - // open 回调 - listener.setOnOpen((s) -> { - // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) - log.info("callback: ws open event emmit"); - }); - // 对话完成的回调 - listener.setOnComplete((s) -> { - log.info("callback: ws complete event emmit"); - SseHelper.send(sseEmitter, "[DONE]"); - // 处理完成后的事件: 保存消息记录, 缓存更改 - ChatSession.Message message = ChatSession.Message.of(userPrompt, s); - HistoryCache.updateContext(sessionId, message); - CosmosPatchOperations options = CosmosPatchOperations.create() - .replace("/updateTime", Instant.now().toEpochMilli()) - .add("/history/-", message); - chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options); - }); - // 错误的回调 - listener.setOnError((s) -> { - log.error("callback: ws error, info: " + s); - // 返还积分 - }); - List messageList = fetchContext(sessionId, userPrompt); - SparkChatRequestParam requestParam = SparkChatRequestParam - .builder() - .uid(userId) - .chatId(sessionId) - .messageList(messageList) - .build(); - sparkGptClient.streamChatCompletion(requestParam, listener); - return sseEmitter; + switch (chatCompletionReqDto.getModel()){ + // 星火大模型 + case "SparkMax":{ + SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); + // open 回调 + listener.setOnOpen((s) -> { + // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) + log.info("callback: ws open event emmit"); + }); + // 对话完成的回调 + SseEmitter finalSseEmitter = sseEmitter; + listener.setOnComplete((s) -> { + log.info("callback: ws complete event emmit"); + SseHelper.send(finalSseEmitter, "[DONE]"); + // 处理完成后的事件: 保存消息记录, 缓存更改 + ChatSession.Message message = ChatSession.Message.of(userPrompt, s,chatCompletionReqDto.getModel()); + HistoryCache.updateContext(sessionId, message); + CosmosPatchOperations options = CosmosPatchOperations.create() + .replace("/updateTime", Instant.now().toEpochMilli()) + .add("/history/-", message); + chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options); + }); + // 错误的回调 + listener.setOnError((s) -> { + log.error("callback: ws error, info: " + s); + // 返还积分 + }); + List messageList = fetchContext(sessionId, userPrompt,chatCompletionReqDto.getModel()); + SparkChatRequestParam requestParam = SparkChatRequestParam + .builder() + .uid(userId) + .chatId(sessionId) + .messageList(messageList) + .build(); + sparkGptClient.streamChatCompletion(requestParam, listener); + return finalSseEmitter; + } + // DeepSeek 模型 + case "DeepSeek_Chat": + { + sseEmitter = deepSeekChatService.ChatSeeEmitterAsk(chatCompletionReqDto ); + return sseEmitter; + } + default:{ + throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "不支持的模型"); + } + } } /** @@ -258,7 +303,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { if (sessions.size() == 0) { // 初始化欢迎语 ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆," + - "你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); + "你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!","SparkMax"); List history = Collections.singletonList(message); session = new ChatSession(); session.setId(sessionId); @@ -285,7 +330,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { log.info("callback: ws complete event emmit"); SseHelper.send(sseEmitter, "[DONE]"); // 处理完成后的事件: 保存消息记录, 缓存更改 - ChatSession.Message message = ChatSession.Message.of(userPrompt, s); + ChatSession.Message message = ChatSession.Message.of(userPrompt, s,"SparkMax"); HistoryCache.updateContext(sessionId, message); CosmosPatchOperations options = CosmosPatchOperations.create() .replace("/updateTime", Instant.now().toEpochMilli()) @@ -297,7 +342,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { log.error("callback: ws error, info: " + s); // 返还积分 }); - List messageList = fetchContext(userId, userPrompt); + List messageList = fetchContext(userId, userPrompt,"SparkMax"); SparkChatRequestParam requestParam = SparkChatRequestParam .builder() .uid(userId) @@ -324,7 +369,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { if (sessions.size() == 0) { // 初始化欢迎语 ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆," + - "你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); + "你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!","SparkMax"); List history = Collections.singletonList(message); session = new ChatSession(); session.setId(sessionId); @@ -351,7 +396,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { log.info("callback: ws complete event emmit"); SseHelper.send(sseEmitter, "[DONE]"); // 处理完成后的事件: 保存消息记录, 缓存更改 - ChatSession.Message message = ChatSession.Message.of(userPrompt, s); + ChatSession.Message message = ChatSession.Message.of(userPrompt, s,"SparkMax"); HistoryCache.updateContext(sessionId, message); CosmosPatchOperations options = CosmosPatchOperations.create() .replace("/updateTime", Instant.now().toEpochMilli()) @@ -363,7 +408,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { log.error("callback: ws error, info: " + s); // 返还积分 }); - List messageList = fetchContext(userId, userPrompt); + List messageList = fetchContext(userId, userPrompt,"SparkMax"); SparkChatRequestParam requestParam = SparkChatRequestParam .builder() .uid(userId) @@ -374,7 +419,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { return sseEmitter; } - List fetchContext(String userId, String prompt) { + List fetchContext(String userId, String prompt, String model) { List context = HistoryCache.getContext(userId); List paramMessages = new ArrayList<>(); // 暂未缓存,从数据库拉取 @@ -388,10 +433,10 @@ public class ChatMessageServiceImpl implements ChatMessageService { // convert DB Message to Spark Message context.forEach(item -> { - paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText())); - paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText())); + paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText(),model)); + paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText(),model)); }); - paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt)); + paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt,model)); return paramMessages; } diff --git a/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java index cd37643..2a911f2 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java @@ -35,7 +35,7 @@ public class ChatSessionServiceImpl implements ChatSessionService { @Override public String createSession(String userId, String name) { // 初始化欢迎语 - Message message = Message.of("", "你好" + name + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); + Message message = Message.of("", "你好" + name + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!","SparkMax"); List history = Collections.singletonList(message); ChatSession chatSession = new ChatSession(); chatSession.setId(UUID.randomUUID().toString()); diff --git a/src/main/java/cn/teammodel/service/impl/DeepSeekServiceImpl.java b/src/main/java/cn/teammodel/service/impl/DeepSeekServiceImpl.java index 4d51c17..cb77f88 100644 --- a/src/main/java/cn/teammodel/service/impl/DeepSeekServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/DeepSeekServiceImpl.java @@ -1,47 +1,72 @@ package cn.teammodel.service.impl; +import cn.teammodel.ai.cache.HistoryCache; import cn.teammodel.ai.deepseek.DeepSeekClient; import cn.teammodel.common.PK; -import cn.teammodel.model.dto.ai.deepseek.ChatRequestOKHttpDto; -import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto; +import cn.teammodel.model.dto.ai.ChatCompletionReqDto; +import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatRequestDto; +import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse; import cn.teammodel.model.dto.ai.deepseek.ChatReqDto; +import cn.teammodel.model.entity.ai.ChatSession; import cn.teammodel.model.entity.ai.DeepSeekSession; import cn.teammodel.model.entity.ai.DeepSeekSession.DeepSeekMessage; +import cn.teammodel.repository.ChatSessionRepository; import cn.teammodel.repository.DeepSeekRepository; import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.service.DeepSeekService; import cn.teammodel.service.DeepSeekSessionService; +import cn.teammodel.utils.RepositoryUtil; +import com.azure.cosmos.models.CosmosPatchOperations; +import com.fasterxml.jackson.databind.JsonNode; +import lombok.extern.slf4j.Slf4j; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; import org.springframework.stereotype.Service; - +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import com.fasterxml.jackson.databind.ObjectMapper; + /** * 描述:访问DeepSeek方法 */ @Service +@Slf4j public class DeepSeekServiceImpl implements DeepSeekService { @Resource private DeepSeekSessionService deepSeekService; @Resource private DeepSeekRepository deepSeekRepository; + @Resource + private ChatSessionRepository chatSessionRepository; + private final ExecutorService executorService = Executors.newCachedThreadPool(); + private final ObjectMapper objectMapper = new ObjectMapper(); /** * 提问 * @param message * @return */ @Override - public ChatResponseDto ChatAsk(ChatReqDto message) { + public DeepSeekChatResponse ChatAsk(ChatReqDto message) { //创建消息列表 List msg = new ArrayList<>(); msg.add(message); //构建请求头 - ChatRequestOKHttpDto requestBody = new ChatRequestOKHttpDto(); + DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto(); requestBody.setModel(DeepSeekClient.API_Model); requestBody.setMessages(msg); requestBody.setTemperature(0); @@ -50,7 +75,7 @@ public class DeepSeekServiceImpl implements DeepSeekService { //开始时间 long startTime = System.currentTimeMillis(); //发起请求 - ChatResponseDto response = DeepSeekClient.SendRequests(requestBody); + DeepSeekChatResponse response = DeepSeekClient.SendRequests(requestBody); //Map response = DeepSeekClient.SendRequests(requestBody); //Map response = SendRequest(requestBody); //结束时间 @@ -77,6 +102,108 @@ public class DeepSeekServiceImpl implements DeepSeekService { return response; } + + @Override + public SseEmitter OKHttpChatSeeEmitterAsk(ChatReqDto message) { + SseEmitter se = new SseEmitter(-1L); + //创建消息列表 + List msg = new ArrayList<>(); + msg.add(message); + + //构建请求头 + DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto(); + requestBody.setModel(DeepSeekClient.API_Model); + requestBody.setMessages(msg); + requestBody.setTemperature(0); + requestBody.setMax_tokens(1024); + requestBody.setStream(true); + if (requestBody.getStream()){ + se = DeepSeekClient.SendRequestsEmitter(requestBody); + + }else { + + DeepSeekChatResponse response = DeepSeekClient.SendRequests(requestBody); + } + return se; + } + + /** + * 提问 流式回答 + * @param chatCompletionReqDto + * @return + */ + @Override + public SseEmitter ChatSeeEmitterAsk(ChatCompletionReqDto chatCompletionReqDto) { + SseEmitter sseEmitter = new SseEmitter(-1L); + StringBuilder strContent = new StringBuilder(); + executorService.execute(()-> { + try { + log.info("流式回答开始,问题:{}", chatCompletionReqDto.getText()); + try (CloseableHttpClient client = HttpClients.createDefault()) { + HttpPost httpPost = new HttpPost(DeepSeekClient.API_Url); + httpPost.setHeader("Content-Type", "application/json"); + httpPost.setHeader("Accept", "application/json"); + httpPost.setHeader("Authorization", "Bearer " + DeepSeekClient.API_Key); + + Map question = new HashMap<>(); + question.put("role", "user"); + question.put("content", chatCompletionReqDto.getText()); + + Map requestMap = new HashMap<>(); + requestMap.put("model", DeepSeekClient.API_Model); + requestMap.put("messages", Collections.singletonList(question)); + requestMap.put("stream", true); + + String requestBody = objectMapper.writeValueAsString(requestMap); + httpPost.setEntity(new StringEntity(requestBody, StandardCharsets.UTF_8)); + StringBuilder responseBody = new StringBuilder(); + try (CloseableHttpResponse response = client.execute(httpPost); + BufferedReader reader = new BufferedReader( + new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + if (line.startsWith("data: ")) { + String jsonData = line.substring(6); + if ("[DONE]".equals(jsonData)) { + sseEmitter.send("[DONE]"); + // 会话完成,更新历史会话记录 + ChatSession.Message message = ChatSession.Message.of(chatCompletionReqDto.getText(), strContent.toString(),chatCompletionReqDto.getModel()); + HistoryCache.updateContext(chatCompletionReqDto.getSessionId(), message); + CosmosPatchOperations options = CosmosPatchOperations.create() + .replace("/updateTime", Instant.now().toEpochMilli()) + .add("/history/-", message); + chatSessionRepository.save(chatCompletionReqDto.getSessionId(), PK.of(PK.CHAT_SESSION), ChatSession.class, options); + break; + } + JsonNode node = objectMapper.readTree(jsonData); + String content = node.path("choices") + .path(0) + .path("delta") + .path("content") + .asText(""); + if (!content.isEmpty()) { + responseBody.append(content); + strContent.append(content); + sseEmitter.send(content); + } + } + } + log.info("流式回答结束,{}",question); + sseEmitter.complete(); + } + } catch (Exception e) { + log.error("处理 Deepseek 请求时发生错误", e); + sseEmitter.completeWithError(e); + } + } catch (Exception e) { + log.error("处理 Deepseek 请求时发生错误", e); + sseEmitter.completeWithError(e); + } + }); + + return sseEmitter; + } + //region 辅助方法 /** * 新增/更新会话 @@ -85,7 +212,7 @@ public class DeepSeekServiceImpl implements DeepSeekService { * @param savaMessage * @param response */ - private void UpdateSession(ChatReqDto message, DeepSeekSession session, DeepSeekMessage savaMessage, ChatResponseDto response) { + private void UpdateSession(ChatReqDto message, DeepSeekSession session, DeepSeekMessage savaMessage, DeepSeekChatResponse response) { if (session.getId() == null){ List history = Collections.singletonList(savaMessage); String userId = SecurityUtil.getLoginUser().getId();