up:修改DeeSeekAI对话 和会话记录保存在和星火大模型中一起的

develop
PL 1 month ago
parent c6c301c532
commit 3568abad10

@ -6,10 +6,9 @@ import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.*;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
@ -69,7 +68,6 @@ public class SparkGptClient implements InitializingBean {
e.printStackTrace();
}
}
/**
* URL
*/
@ -111,6 +109,7 @@ public class SparkGptClient implements InitializingBean {
return httpUrl.toString();
}
@Override
public void afterPropertiesSet() throws Exception {
init();

@ -1,16 +1,29 @@
package cn.teammodel.ai.deepseek;
import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.common.ErrorCode;
import cn.teammodel.common.PK;
import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.model.dto.ai.deepseek.ChatRequestOKHttpDto;
import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatRequestDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse;
import cn.teammodel.model.dto.ai.deepseek.ChatReqDto;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.repository.ChatSessionRepository;
import cn.teammodel.utils.JsonUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.google.gson.Gson;
import com.sun.org.apache.bcel.internal.generic.NEW;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okio.Buffer;
import okio.BufferedSource;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
@ -19,20 +32,35 @@ import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@Slf4j
public class DeepSeekClient {
private static final String API_Key;
private static final String API_Url;
public class DeepSeekClient {
public static final String API_Key;
public static final String API_Url;
public static String API_Model;
@Resource
private static ChatSessionRepository chatSessionRepository;
private static final ExecutorService executorService = Executors.newCachedThreadPool();
private static final ObjectMapper objectMapper = new ObjectMapper();
/**
* key url
*/
@ -61,7 +89,7 @@ public class DeepSeekClient {
msg.add(mssage);
//构建请求头
ChatRequestOKHttpDto requestBody = new ChatRequestOKHttpDto();
DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto();
requestBody.setModel(API_Model);
requestBody.setMessages(msg);
requestBody.setTemperature(0);
@ -69,7 +97,7 @@ public class DeepSeekClient {
long startTime = System.currentTimeMillis();
//发起请求
ChatResponseDto response = SendRequests(requestBody);
DeepSeekChatResponse response = SendRequests(requestBody);
//Map<String, Object> response = SendRequest(requestBody);
Long endTime = System.currentTimeMillis();
//思考耗时 秒
@ -85,9 +113,9 @@ public class DeepSeekClient {
* @param requestBody
* @return
*/
public static ChatResponseDto SendRequests(ChatRequestOKHttpDto requestBody)
public static DeepSeekChatResponse SendRequests(DeepSeekChatRequestDto requestBody)
{
ChatResponseDto chatResponse = new ChatResponseDto();
DeepSeekChatResponse chatResponse = new DeepSeekChatResponse();
//OkHttpClient client = new OkHttpClient().newBuilder().connectTimeout(60, TimeUnit.SECONDS).build();//设置连接超时时间 1分钟
OkHttpClient client = new OkHttpClient().newBuilder().build();//设置连接超时时间 1分钟
@ -104,13 +132,36 @@ public class DeepSeekClient {
.addHeader("Accept", "application/json")
.addHeader("Authorization", "Bearer "+API_Key)
.build();
ObjectMapper objectMapper = new ObjectMapper();
try(Response response = client.newCall(request).execute()) {
if (response.isSuccessful() && response.body() != null) {
String responseBody = response.body().string();
StringBuilder responseBody = new StringBuilder();
try (BufferedSource source = response.body().source()) {
Buffer buffer = new Buffer();
while (source.read(buffer, 2048) != -1) {
// 处理流式数据
String chunk = buffer.readUtf8();
if (chunk.startsWith("data:") && !chunk.contains("data: [DONE]")) {
String[] split = chunk.split("data:");
for (String result : split) {
if (StringUtils.hasLength(result) && StringUtils.hasLength(result.trim())) {
JsonNode jsonNode = objectMapper.readTree(result);
if (!ObjectUtils.isEmpty(jsonNode.get("choices"))) {
JsonNode delta = jsonNode.get("choices").get(0).get("delta");
log.debug("Delta Content: {}", delta.get("content").asText());
responseBody.append(delta.get("content").asText());
}
}
}
}
}
}
String responseBody1 = response.body().string();
// 使用 Gson 将 JSON 字符串转换为 MyEntity 对象
Gson gson = new Gson();
chatResponse = gson.fromJson(responseBody, ChatResponseDto.class);
chatResponse = gson.fromJson(responseBody1, DeepSeekChatResponse.class);
// 确保关闭响应体以释放资源
response.body().close();
chatResponse.setCode(200);
@ -132,12 +183,145 @@ public class DeepSeekClient {
return chatResponse;
}
/**
* OkHttpClient
* @param requestBody
* @return
*/
public static SseEmitter SendRequestsEmitter(DeepSeekChatRequestDto requestBody)
{
SseEmitter sseEmitter = new SseEmitter(-1L);
//OkHttpClient client = new OkHttpClient().newBuilder().connectTimeout(60, TimeUnit.SECONDS).build();//设置连接超时时间 1分钟
OkHttpClient client = new OkHttpClient().newBuilder().build();//
MediaType mediaType = MediaType.parse("application/json");
//String content = "{\n \"messages\": [\n {\n \"content\": \"You are a helpful assistant\",\n \"role\": \"system\"\n },\n {\n \"content\": \"Hi\",\n \"role\": \"user\"\n }\n ],\n \"model\": \"deepseek-chat\",\n \"frequency_penalty\": 0,\n \"max_tokens\": 2048,\n \"presence_penalty\": 0,\n \"response_format\": {\n \"type\": \"text\"\n },\n \"stop\": null,\n \"stream\": false,\n \"stream_options\": null,\n \"temperature\": 1,\n \"top_p\": 1,\n \"tools\": null,\n \"tool_choice\": \"none\",\n \"logprobs\": false,\n \"top_logprobs\": null\n}";
String content = JsonUtil.convertToJson(requestBody);
RequestBody body = RequestBody.create(mediaType, content);
Request request = new Request.Builder()
.url(API_Url)
.method("POST", body)
.addHeader("Content-Type", "application/json")
.addHeader("Accept", "application/json")
.addHeader("Authorization", "Bearer "+API_Key)
.build();
ObjectMapper objectMapper = new ObjectMapper();
try(Response response = client.newCall(request).execute()) {
if (response.isSuccessful() && response.body() != null) {
StringBuilder responseBody = new StringBuilder();
try (BufferedSource source = response.body().source()) {
Buffer buffer = new Buffer();
while (source.read(buffer, 2048) != -1) {
// 处理流式数据
String chunk = buffer.readUtf8();
if (chunk.startsWith("data:") && !chunk.contains("data: [DONE]")) {
String[] split = chunk.split("data:");
for (String result : split) {
if (StringUtils.hasLength(result) && StringUtils.hasLength(result.trim())) {
JsonNode jsonNode = objectMapper.readTree(result);
if (!ObjectUtils.isEmpty(jsonNode.get("choices"))) {
JsonNode delta = jsonNode.get("choices").get(0).get("delta");
log.debug("Delta Content: {}", delta.get("content").asText());
sseEmitter.send(delta);
}
}
}
}
}
}catch (IOException e) {
sseEmitter.completeWithError(e);
}
} else {
sseEmitter.completeWithError(new Exception("请求DeepSeek服务器失败"));
}
} catch (IOException e) {
sseEmitter.completeWithError(e);
}
return sseEmitter;
}
/**
* HttpClient
* @param chatCompletionReqDto
* @return
*/
public static SseEmitter HttpClientSendRequests(ChatCompletionReqDto chatCompletionReqDto){
SseEmitter emitter = new SseEmitter(-1L);
List<ChatReqDto> msg = new ArrayList<>();
msg.add(new ChatReqDto(chatCompletionReqDto.getSessionId(), "user", chatCompletionReqDto.getText()));
//构建请求头
DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto();
requestBody.setModel(DeepSeekClient.API_Model);
requestBody.setMessages(msg);
requestBody.setTemperature(0);
requestBody.setStream(true);
try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
// 创建HttpPost对象
HttpPost httpPost = new HttpPost(API_Url);
//添加请求头
httpPost.setHeader("Content-Type", "application/json");
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Authorization", "Bearer " + API_Key);
requestBody.setStream(true);
// 设置请求体
String jsonContent = JsonUtil.convertToJson(requestBody);
httpPost.setEntity(new StringEntity(jsonContent, ContentType.create("application/json", "UTF-8")));
StringBuilder responseBody = new StringBuilder();
try (CloseableHttpResponse response = httpClient.execute(httpPost);
BufferedReader reader = new BufferedReader(new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8))) {
String line;
StringBuilder strContent = new StringBuilder();
while ((line = reader.readLine()) != null) {
if (line.startsWith("data: ")) {
String jsonData = line.substring(6);
if ("[DONE]".equals(jsonData)) {
//SseHelper.send(emitter, "[DONE]");
emitter.send("[DONE]");
// 更新历史会话记录
ChatSession.Message message = ChatSession.Message.of(chatCompletionReqDto.getText(), strContent.toString(),chatCompletionReqDto.getModel());
HistoryCache.updateContext(chatCompletionReqDto.getSessionId(), message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(chatCompletionReqDto.getSessionId(), PK.of(PK.CHAT_SESSION), ChatSession.class, options);
break;
}
JsonNode node = objectMapper.readTree(jsonData);
String content = node.path("choices")
.path(0)
.path("delta")
.path("content")
.asText("");
if (!content.isEmpty()) {
responseBody.append(content);
strContent.append(content);
emitter.send(content);
}
}
}
emitter.complete();
}catch (Exception e)
{
emitter.completeWithError(e);
}
}catch (Exception e) {
emitter.completeWithError(e);
}
return emitter;
}
/***
* HttpClient
* @param requestBody
* @return
*/
public static Map<String, Object> SendRequest(ChatRequestOKHttpDto requestBody) {
public static Map<String, Object> SendRequest(DeepSeekChatRequestDto requestBody) {
Map<String, Object> mapper = new HashMap<>();
try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
// 创建HttpPost对象
@ -175,8 +359,4 @@ public class DeepSeekClient {
//TODO 请求接口
return mapper;
}
}

@ -46,16 +46,17 @@ public class SparkChatRequestParam {
public static class Message {
private String role;
private String content;
private String model;
/**
* ,使,
*/
public static Message ofUser(String content){
return new Message("user",content);
public static Message ofUser(String content,String model){
return new Message("user",content,model);
}
public static Message ofAssistant(String content){
return new Message("assistant",content);
public static Message ofAssistant(String content, String model){
return new Message("assistant",content,model);
}
}

@ -1,7 +1,8 @@
package cn.teammodel.controller.frontend;
import cn.teammodel.common.IdRequest;
import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse;
import cn.teammodel.model.dto.ai.deepseek.ChatReqDto;
import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.model.entity.ai.DeepSeekSession;
@ -18,12 +19,15 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import javax.validation.Valid;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@RestController
@RequestMapping("/aiDeepSeek")
@Api(tags = "AI DeepSeek 能力")
public class AiDeepSeekController {
private final ExecutorService executorService = Executors.newCachedThreadPool();
/**
* 访DeepSeek
*/
@ -118,19 +122,29 @@ public class AiDeepSeekController {
* @return
*/
@PostMapping("chat")
@ApiOperation("与deepseek的对话")
public R<ChatResponseDto> ChatCompletion(@RequestBody @Valid ChatReqDto messageDto) {
ChatResponseDto chatResponse = deepSeekChatService.ChatAsk(messageDto);
@ApiOperation("单独 与deepseek的对话")
public R<DeepSeekChatResponse> ChatCompletion(@RequestBody @Valid ChatReqDto messageDto) {
DeepSeekChatResponse chatResponse = deepSeekChatService.ChatAsk(messageDto);
return R.success(chatResponse);
}
@PostMapping("okhttp/emitter")
@ApiOperation("单 与deepseek的对话")
public SseEmitter ChatEmiter(@RequestBody @Valid ChatReqDto messageDto) {
return deepSeekChatService.OKHttpChatSeeEmitterAsk(messageDto);
}
/**
* deepseek
* @param chatCompletionReqDto
* @return
*/
@PostMapping("chat/completion")
@ApiOperation("与 spark 的流式对话")
public SseEmitter chatCompletion(@RequestBody @Valid ChatReqDto messageDto) {
public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) {
String userId = SecurityUtil.getLoginUser().getId();
SseEmitter sseEmitter = new SseEmitter();
return sseEmitter;
//return deepSeekChatService.ChatSeeEmitterAsk(messageDto, userId);
SseEmitter sseEmitter = new SseEmitter(-1L);
return deepSeekChatService.ChatSeeEmitterAsk(chatCompletionReqDto);
}

@ -13,6 +13,10 @@ public class ChatCompletionReqDto {
*/
@ApiModelProperty("会话id没有则为空")
private String appId;
@ApiModelProperty("模型")
private String model = "SparkMax";
@NotBlank(message = "请输入消息内容")
private String text;
}

@ -1,66 +0,0 @@
package cn.teammodel.model.dto.ai.deepseek;
import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Builder.Default;
import lombok.Data;
import java.util.List;
/**
* DeepSeekjson
*/
@Data
public class ChatResponseDto {
private int code;
private String msg;
private long wasteTime;
private String id;
private String object;
private long created;
private String model;
private Usage usage;
/**
*
*/
private List<Choice> choices;
private String system_fingerprint;
@Data
@AllArgsConstructor
public static class Choice {
private int index;
private DeepSeekMessage message;
private String logprobs;
private String finish_reason;
}
@Data
public static class Usage{
private int prompt_tokens;
private int completion_tokens;
private int total_tokens;
private Prompt_Tokens_Details prompt_tokens_details;
private int prompt_cache_hit_tokens;
private int prompt_cache_miss_tokens;
}
@Data
public static class Prompt_Tokens_Details {
private int cached_tokens;
}
@Data
public static class DeepSeekMessage{
/**
*
*/
private String role;
/**
*
*/
private String content;
}
}

@ -37,7 +37,7 @@ import java.util.List;
* }
*/
@Data
public class ChatRequestOKHttpDto {
public class DeepSeekChatRequestDto {
@ApiModelProperty("会话模型")
private String model;
@ApiModelProperty("会话内容")
@ -53,4 +53,11 @@ public class ChatRequestOKHttpDto {
*/
@ApiModelProperty("最大生成token数")
private int max_tokens = 2048;
@ApiModelProperty("是否流式输出")
private boolean stream = false;
public boolean getStream() {
return stream;
}
}

@ -0,0 +1,114 @@
package cn.teammodel.model.dto.ai.deepseek;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.List;
/**
* DeepSeekjson
*/
@Data
public class DeepSeekChatResponse {
/**
*
*/
private int code;
/**
*
*/
private String msg;
/**
*
*/
private long wasteTime;
/**
* id
*/
private String id;
/**
* , chat.completion
*/
private String object;
/**
* Unix
*/
private long created;
/**
* completion
*/
private String model;
/**
*
*/
private List<Choice> choices;
/**
*
*/
private Usage usage;
private String system_fingerprint;
/**
* completion
*/
@Data
@AllArgsConstructor
public static class Choice {
/**
* completion completion
*/
private int index;
//内容
private DeepSeekMessage message;
//private String logprobs;
private String finish_reason;
}
/**
*
*/
@Data
public static class Usage{
//用户 prompt 所包含的 token 数。该值等于 prompt_cache_hit_tokens + prompt_cache_miss_tokens
private int prompt_tokens;
//模型 completion 产生的 token 数。
private int completion_tokens;
//该请求中,所有 token 的数量prompt + completion
private int total_tokens;
private Prompt_Tokens_Details prompt_tokens_details;
//用户 prompt 中,命中上下文缓存的 token 数。
private int prompt_cache_hit_tokens;
//用户 prompt 中,未命中上下文缓存的 token 数。
private int prompt_cache_miss_tokens;
}
/**
* completion tokens
*/
@Data
public static class Prompt_Tokens_Details {
//推理模型所产生的思维链 token 数量
private int cached_tokens;
}
/**
*
*/
@Data
public static class DeepSeekMessage{
/**
*
*/
private String role;
/**
* completion
*/
private String content;
}
}

@ -45,13 +45,16 @@ public class ChatSession extends BaseItem {
private Integer cost;
private Long createTime;
public static Message of(String userText, String gptText) {
public String model;
public static Message of(String userText, String gptText,String model) {
Message message = new Message();
message.setId(UUID.randomUUID().toString());
message.setCost(0);
message.setUserText(userText);
message.setGptText(gptText);
message.setCreateTime(Instant.now().toEpochMilli());
message.setModel(model);
return message;
}
}

@ -1,18 +1,34 @@
package cn.teammodel.service;
import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse;
import cn.teammodel.model.dto.ai.deepseek.ChatReqDto;
import reactor.core.publisher.Flux;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* 访DeepSeek
*/
public interface DeepSeekService {
/**
* AI
* @param message
* @return
*/
DeepSeekChatResponse ChatAsk(ChatReqDto message);
/**
* AI
* @param message
* @return
*/
ChatResponseDto ChatAsk(ChatReqDto message);
SseEmitter OKHttpChatSeeEmitterAsk(ChatReqDto message);
/**
* AI
* @param chatCompletionReqDto
* @return
*/
SseEmitter ChatSeeEmitterAsk(ChatCompletionReqDto chatCompletionReqDto);
}

@ -1,6 +1,5 @@
package cn.teammodel.service;
import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto;
import cn.teammodel.model.entity.ai.DeepSeekSession;
import cn.teammodel.model.entity.ai.DeepSeekSession.DeepSeekMessage;

@ -4,6 +4,7 @@ import cn.teammodel.ai.JsonLoader;
import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.deepseek.DeepSeekClient;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import cn.teammodel.common.ErrorCode;
@ -18,6 +19,7 @@ import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatMessageService;
import cn.teammodel.service.DeepSeekService;
import cn.teammodel.utils.RepositoryUtil;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.TypeReference;
@ -49,12 +51,20 @@ public class ChatMessageServiceImpl implements ChatMessageService {
@Resource
private JsonLoader jsonLoader;
/**
* 访DeepSeek
*/
@Resource
private DeepSeekService deepSeekChatService;
@Override
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) {
// 目前仅使用讯飞星火大模型
String appId = chatCompletionReqDto.getAppId();
SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) {
//
if (StringUtils.isEmpty(appId) || chatCompletionReqDto.getModel().equals("DeepSeek_Chat")) {
sseEmitter = completionBySession(chatCompletionReqDto, userId);
} else {
sseEmitter = completionByApp(chatCompletionReqDto, false);
@ -165,37 +175,57 @@ public class ChatMessageServiceImpl implements ChatMessageService {
String appPrompt = chatApp.getPrompt();
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件:
if (!justApi) {
// 保存消息记录, 缓存更改
switch (chatCompletionReqDto.getModel()) {
//星火大模型
case "SparkMax":
{
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
SseEmitter finalSseEmitter = sseEmitter;
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(finalSseEmitter, "[DONE]");
// 处理完成后的事件:
if (!justApi) {
// 保存消息记录, 缓存更改
}
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = new ArrayList<>();
messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt,"SparkMax"));
messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt,"SparkMax"));
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(appId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return finalSseEmitter;
}
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = new ArrayList<>();
messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt));
messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt));
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(appId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
// DeepSeek 模型
case "DeepSeek_Chat":
{
// OKHttp 方式请求
sseEmitter = deepSeekChatService.ChatSeeEmitterAsk(chatCompletionReqDto);
//HttpClient 方式请求
//sseEmitter = DeepSeekClient.HttpClientSendRequests(chatCompletionReqDto);
return sseEmitter;
}
default:
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "不支持的模型");
}
}
/**
@ -211,38 +241,53 @@ public class ChatMessageServiceImpl implements ChatMessageService {
}
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt);
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(sessionId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
switch (chatCompletionReqDto.getModel()){
// 星火大模型
case "SparkMax":{
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit");
});
// 对话完成的回调
SseEmitter finalSseEmitter = sseEmitter;
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(finalSseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s,chatCompletionReqDto.getModel());
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(sessionId, PK.of(PK.CHAT_SESSION), ChatSession.class, options);
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(sessionId, userPrompt,chatCompletionReqDto.getModel());
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId(sessionId)
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return finalSseEmitter;
}
// DeepSeek 模型
case "DeepSeek_Chat":
{
sseEmitter = deepSeekChatService.ChatSeeEmitterAsk(chatCompletionReqDto );
return sseEmitter;
}
default:{
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "不支持的模型");
}
}
}
/**
@ -258,7 +303,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
if (sessions.size() == 0) {
// 初始化欢迎语
ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆," +
"你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
"你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!","SparkMax");
List<ChatSession.Message> history = Collections.singletonList(message);
session = new ChatSession();
session.setId(sessionId);
@ -285,7 +330,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
ChatSession.Message message = ChatSession.Message.of(userPrompt, s,"SparkMax");
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
@ -297,7 +342,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(userId, userPrompt);
List<SparkChatRequestParam.Message> messageList = fetchContext(userId, userPrompt,"SparkMax");
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
@ -324,7 +369,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
if (sessions.size() == 0) {
// 初始化欢迎语
ChatSession.Message message = ChatSession.Message.of("", "你好" + userName + " ,我是你的私人 AI 助手小豆," +
"你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
"你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!","SparkMax");
List<ChatSession.Message> history = Collections.singletonList(message);
session = new ChatSession();
session.setId(sessionId);
@ -351,7 +396,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录, 缓存更改
ChatSession.Message message = ChatSession.Message.of(userPrompt, s);
ChatSession.Message message = ChatSession.Message.of(userPrompt, s,"SparkMax");
HistoryCache.updateContext(sessionId, message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
@ -363,7 +408,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
log.error("callback: ws error, info: " + s);
// 返还积分
});
List<SparkChatRequestParam.Message> messageList = fetchContext(userId, userPrompt);
List<SparkChatRequestParam.Message> messageList = fetchContext(userId, userPrompt,"SparkMax");
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
@ -374,7 +419,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
return sseEmitter;
}
List<SparkChatRequestParam.Message> fetchContext(String userId, String prompt) {
List<SparkChatRequestParam.Message> fetchContext(String userId, String prompt, String model) {
List<ChatSession.Message> context = HistoryCache.getContext(userId);
List<SparkChatRequestParam.Message> paramMessages = new ArrayList<>();
// 暂未缓存,从数据库拉取
@ -388,10 +433,10 @@ public class ChatMessageServiceImpl implements ChatMessageService {
// convert DB Message to Spark Message
context.forEach(item -> {
paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText()));
paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText()));
paramMessages.add(SparkChatRequestParam.Message.ofUser(item.getUserText(),model));
paramMessages.add(SparkChatRequestParam.Message.ofAssistant(item.getGptText(),model));
});
paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt));
paramMessages.add(SparkChatRequestParam.Message.ofUser(prompt,model));
return paramMessages;
}

@ -35,7 +35,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
@Override
public String createSession(String userId, String name) {
// 初始化欢迎语
Message message = Message.of("", "你好" + name + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
Message message = Message.of("", "你好" + name + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!","SparkMax");
List<Message> history = Collections.singletonList(message);
ChatSession chatSession = new ChatSession();
chatSession.setId(UUID.randomUUID().toString());

@ -1,47 +1,72 @@
package cn.teammodel.service.impl;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.deepseek.DeepSeekClient;
import cn.teammodel.common.PK;
import cn.teammodel.model.dto.ai.deepseek.ChatRequestOKHttpDto;
import cn.teammodel.model.dto.ai.deepseek.ChatResponseDto;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatRequestDto;
import cn.teammodel.model.dto.ai.deepseek.DeepSeekChatResponse;
import cn.teammodel.model.dto.ai.deepseek.ChatReqDto;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.model.entity.ai.DeepSeekSession;
import cn.teammodel.model.entity.ai.DeepSeekSession.DeepSeekMessage;
import cn.teammodel.repository.ChatSessionRepository;
import cn.teammodel.repository.DeepSeekRepository;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.DeepSeekService;
import cn.teammodel.service.DeepSeekSessionService;
import cn.teammodel.utils.RepositoryUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* 访DeepSeek
*/
@Service
@Slf4j
public class DeepSeekServiceImpl implements DeepSeekService {
@Resource
private DeepSeekSessionService deepSeekService;
@Resource
private DeepSeekRepository deepSeekRepository;
@Resource
private ChatSessionRepository chatSessionRepository;
private final ExecutorService executorService = Executors.newCachedThreadPool();
private final ObjectMapper objectMapper = new ObjectMapper();
/**
*
* @param message
* @return
*/
@Override
public ChatResponseDto ChatAsk(ChatReqDto message) {
public DeepSeekChatResponse ChatAsk(ChatReqDto message) {
//创建消息列表
List<ChatReqDto> msg = new ArrayList<>();
msg.add(message);
//构建请求头
ChatRequestOKHttpDto requestBody = new ChatRequestOKHttpDto();
DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto();
requestBody.setModel(DeepSeekClient.API_Model);
requestBody.setMessages(msg);
requestBody.setTemperature(0);
@ -50,7 +75,7 @@ public class DeepSeekServiceImpl implements DeepSeekService {
//开始时间
long startTime = System.currentTimeMillis();
//发起请求
ChatResponseDto response = DeepSeekClient.SendRequests(requestBody);
DeepSeekChatResponse response = DeepSeekClient.SendRequests(requestBody);
//Map<String, Object> response = DeepSeekClient.SendRequests(requestBody);
//Map<String, Object> response = SendRequest(requestBody);
//结束时间
@ -77,6 +102,108 @@ public class DeepSeekServiceImpl implements DeepSeekService {
return response;
}
@Override
public SseEmitter OKHttpChatSeeEmitterAsk(ChatReqDto message) {
SseEmitter se = new SseEmitter(-1L);
//创建消息列表
List<ChatReqDto> msg = new ArrayList<>();
msg.add(message);
//构建请求头
DeepSeekChatRequestDto requestBody = new DeepSeekChatRequestDto();
requestBody.setModel(DeepSeekClient.API_Model);
requestBody.setMessages(msg);
requestBody.setTemperature(0);
requestBody.setMax_tokens(1024);
requestBody.setStream(true);
if (requestBody.getStream()){
se = DeepSeekClient.SendRequestsEmitter(requestBody);
}else {
DeepSeekChatResponse response = DeepSeekClient.SendRequests(requestBody);
}
return se;
}
/**
*
* @param chatCompletionReqDto
* @return
*/
@Override
public SseEmitter ChatSeeEmitterAsk(ChatCompletionReqDto chatCompletionReqDto) {
SseEmitter sseEmitter = new SseEmitter(-1L);
StringBuilder strContent = new StringBuilder();
executorService.execute(()-> {
try {
log.info("流式回答开始,问题:{}", chatCompletionReqDto.getText());
try (CloseableHttpClient client = HttpClients.createDefault()) {
HttpPost httpPost = new HttpPost(DeepSeekClient.API_Url);
httpPost.setHeader("Content-Type", "application/json");
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Authorization", "Bearer " + DeepSeekClient.API_Key);
Map<String, Object> question = new HashMap<>();
question.put("role", "user");
question.put("content", chatCompletionReqDto.getText());
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("model", DeepSeekClient.API_Model);
requestMap.put("messages", Collections.singletonList(question));
requestMap.put("stream", true);
String requestBody = objectMapper.writeValueAsString(requestMap);
httpPost.setEntity(new StringEntity(requestBody, StandardCharsets.UTF_8));
StringBuilder responseBody = new StringBuilder();
try (CloseableHttpResponse response = client.execute(httpPost);
BufferedReader reader = new BufferedReader(
new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (line.startsWith("data: ")) {
String jsonData = line.substring(6);
if ("[DONE]".equals(jsonData)) {
sseEmitter.send("[DONE]");
// 会话完成,更新历史会话记录
ChatSession.Message message = ChatSession.Message.of(chatCompletionReqDto.getText(), strContent.toString(),chatCompletionReqDto.getModel());
HistoryCache.updateContext(chatCompletionReqDto.getSessionId(), message);
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/updateTime", Instant.now().toEpochMilli())
.add("/history/-", message);
chatSessionRepository.save(chatCompletionReqDto.getSessionId(), PK.of(PK.CHAT_SESSION), ChatSession.class, options);
break;
}
JsonNode node = objectMapper.readTree(jsonData);
String content = node.path("choices")
.path(0)
.path("delta")
.path("content")
.asText("");
if (!content.isEmpty()) {
responseBody.append(content);
strContent.append(content);
sseEmitter.send(content);
}
}
}
log.info("流式回答结束,{}",question);
sseEmitter.complete();
}
} catch (Exception e) {
log.error("处理 Deepseek 请求时发生错误", e);
sseEmitter.completeWithError(e);
}
} catch (Exception e) {
log.error("处理 Deepseek 请求时发生错误", e);
sseEmitter.completeWithError(e);
}
});
return sseEmitter;
}
//region 辅助方法
/**
* /
@ -85,7 +212,7 @@ public class DeepSeekServiceImpl implements DeepSeekService {
* @param savaMessage
* @param response
*/
private void UpdateSession(ChatReqDto message, DeepSeekSession session, DeepSeekMessage savaMessage, ChatResponseDto response) {
private void UpdateSession(ChatReqDto message, DeepSeekSession session, DeepSeekMessage savaMessage, DeepSeekChatResponse response) {
if (session.getId() == null){
List<DeepSeekMessage> history = Collections.singletonList(savaMessage);
String userId = SecurityUtil.getLoginUser().getId();

Loading…
Cancel
Save