feat: chat session 的实现

11111
winter 1 year ago
parent fee64cb1ac
commit 6ea7454973

@ -38,9 +38,9 @@ public class SparkGptClient implements InitializingBean {
* *
*/ */
public void init() { public void init() {
String authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret()); authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret());
this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://"); this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://");
log.info("鉴权 url: {}", this.authUrl); log.info("[SPARK CHAT] 鉴权 url: {}", this.authUrl);
this.okHttpClient = new OkHttpClient() this.okHttpClient = new OkHttpClient()
.newBuilder() .newBuilder()

@ -4,6 +4,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Builder.Default;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
@ -22,16 +23,16 @@ public class SparkChatRequestParam {
//每个用户的id用于区分不同用户 //每个用户的id用于区分不同用户
private String uid; private String uid;
//指定访问的领域,general指向V1.5版本 generalv2指向V2版本。注意不同的取值对应的url也不一样 //指定访问的领域,general指向V1.5版本 generalv2指向V2版本。注意不同的取值对应的url也不一样
@Builder.Default @Default
private String domain = "generalv3"; private String domain = "generalv3";
//核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高 //核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高
@Builder.Default @Default
private Float temperature = 0.5F; private Float temperature = 0.5F;
//模型回答的tokens的最大长度 //模型回答的tokens的最大长度
@Builder.Default @Default
private Integer maxTokens = 2048; private Integer maxTokens = 2048;
//从k个候选中随机选择⼀个⾮等概率 //从k个候选中随机选择⼀个⾮等概率
@Builder.Default @Default
private Integer top_k = 4; private Integer top_k = 4;
//用于关联用户会话 //用于关联用户会话
private String chatId; private String chatId;

@ -20,7 +20,7 @@ import javax.validation.constraints.NotNull;
import java.util.function.Consumer; import java.util.function.Consumer;
/** /**
* okhttp ws listner * okhttp ws listener
* @author winter * @author winter
* @create 2023-12-15 16:17 * @create 2023-12-15 16:17
*/ */
@ -44,7 +44,6 @@ public class SparkGptStreamListener extends WebSocketListener {
try { try {
onOpen.accept(webSocket); onOpen.accept(webSocket);
} catch (Exception e) { } catch (Exception e) {
// todo: 这儿不应该直接调她
this.onFailure(webSocket, e, response); this.onFailure(webSocket, e, response);
} }
} }
@ -76,7 +75,6 @@ public class SparkGptStreamListener extends WebSocketListener {
onComplete.accept(answer); onComplete.accept(answer);
SseHelper.complete(sseEmitter); SseHelper.complete(sseEmitter);
} }
} }
@Override @Override
@ -89,7 +87,7 @@ public class SparkGptStreamListener extends WebSocketListener {
} }
// 这几个 function 可以在 listener 被调用时设置, 实现类似事件的回调 // 这几个 function 可以在 listener 被调用时设置, 实现类似事件的回调(可以使用模板方法模式实现)
protected Consumer<WebSocket> onOpen = (s) -> {}; protected Consumer<WebSocket> onOpen = (s) -> {};
protected Consumer<Throwable> onError = (s) -> {}; protected Consumer<Throwable> onError = (s) -> {};
protected Consumer<String> onComplete = (s) -> {}; protected Consumer<String> onComplete = (s) -> {};

@ -23,6 +23,7 @@ public interface PK {
*/ */
String STUDENT = "Base-%s"; String STUDENT = "Base-%s";
String CLASS = "Class-%s"; String CLASS = "Class-%s";
String CHAT_SESSION = "ChatSession";
/** /**
* *

@ -1,17 +1,25 @@
package cn.teammodel.controller.frontend; package cn.teammodel.controller.frontend;
import cn.teammodel.common.IdRequest;
import cn.teammodel.common.R;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto; import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.dto.ai.UpdateSessionDto;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.service.ChatMessageService; import cn.teammodel.service.ChatMessageService;
import cn.teammodel.service.ChatSessionService;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource; import javax.annotation.Resource;
import javax.validation.Valid; import javax.validation.Valid;
import java.util.List;
@RestController @RestController
@RequestMapping("/public/ai") @RequestMapping("/ai")
public class AiController { public class AiController {
@Resource
private ChatSessionService chatSessionService;
@Resource @Resource
private ChatMessageService chatMessageService; private ChatMessageService chatMessageService;
@ -20,9 +28,34 @@ public class AiController {
public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) { public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) {
return chatMessageService.chatCompletion(chatCompletionReqDto); return chatMessageService.chatCompletion(chatCompletionReqDto);
} }
@GetMapping("test/completion")
@ApiOperation("与 spark 的流式对话") @GetMapping("session/my")
public SseEmitter testChatCompletion() { @ApiOperation("查询我的聊天会话")
return chatMessageService.chatCompletion(null); public R<List<ChatSession>> listMySession() {
List<ChatSession> sessions = chatSessionService.listMySession();
return R.success(sessions);
}
@PostMapping("session/create")
@ApiOperation("创建聊天会话")
public R<String> createSession() {
chatSessionService.createSession();
return R.success("创建会话成功");
}
@PostMapping("session/remove")
@ApiOperation("删除聊天会话")
public R<String> removeSession(@RequestBody @Valid IdRequest idRequest) {
chatSessionService.deleteSession(idRequest.getId());
return R.success("删除会话成功");
} }
@PostMapping("session/update")
@ApiOperation("更新聊天会话")
public R<ChatSession> updateSession(@RequestBody @Valid UpdateSessionDto updateSessionDto) {
ChatSession session = chatSessionService.updateSession(updateSessionDto);
return R.success(session);
}
} }

@ -0,0 +1,21 @@
package cn.teammodel.dao;
import cn.teammodel.model.entity.ai.ChatSession;
import com.azure.spring.data.cosmos.repository.CosmosRepository;
import com.azure.spring.data.cosmos.repository.Query;
import org.springframework.stereotype.Repository;
import java.util.List;
/**
* @author winter
* @create 2023-11-28 17:39
*/
@Repository
public interface ChatSessionRepository extends CosmosRepository<ChatSession, String> {
@Query("select c.id, c.code, c.title, c.userId, c.createTime from c where c.code = 'ChatSession' and c.sessionId = @sessionId")
List<ChatSession> findBySessionId(String sessionId);
@Query("select c.id, c.code, c.title, c.userId, c.createTime from c where c.code = 'ChatSession' and c.userId = @userId")
List<ChatSession> findByUserId(String userId);
}

@ -0,0 +1,18 @@
package cn.teammodel.model.dto.ai;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.NotBlank;
/**
* @author winter
* @create 2023-12-19 15:42
*/
@Data
public class UpdateSessionDto {
@ApiModelProperty(value = "session id", required = true)
@NotBlank
private String id;
private String title;
}

@ -0,0 +1,66 @@
package cn.teammodel.model.entity.ai;
import cn.hutool.core.lang.UUID;
import cn.teammodel.model.entity.BaseItem;
import com.azure.spring.data.cosmos.core.mapping.Container;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.*;
import java.time.Instant;
import java.util.List;
/**
* , teacherId(userId), id: sessionId
* @author winter
* @create 2023-12-19 15:09
*/
@EqualsAndHashCode(callSuper = true)
@Container(containerName = "Teacher")
@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ChatSession extends BaseItem {
/**
*
*/
private String title;
/**
* id
*/
private String userId;
private Long createTime;
/**
* ,
*/
private Long updateTime;
private List<Message> history;
@Data
public static class Message {
private String id;
private String userText;
private String gptText;
/**
* point
*/
private Integer cost;
private Long createTime;
public static Message ofUserText(String userText) {
Message message = new Message();
message.setId(UUID.randomUUID().toString());
message.setCost(0);
message.setUserText(userText);
message.setCreateTime(Instant.now().toEpochMilli());
return message;
}
public static Message ofGptText(String gptText) {
Message message = new Message();
message.setId(UUID.randomUUID().toString());
message.setCost(0);
message.setGptText(gptText);
message.setCreateTime(Instant.now().toEpochMilli());
return message;
}
}
}

@ -0,0 +1,21 @@
package cn.teammodel.service;
import cn.teammodel.model.dto.ai.UpdateSessionDto;
import cn.teammodel.model.entity.ai.ChatSession;
import java.util.List;
/**
* @author winter
* @create 2023-12-19 15:30
*/
public interface ChatSessionService {
void createSession();
List<ChatSession> listMySession();
ChatSession updateSession(UpdateSessionDto updateSessionDto);
void deleteSession(String id);
}

@ -32,14 +32,12 @@ public class ChatMessageServiceImpl implements ChatMessageService {
User user = SecurityUtil.getLoginUser(); User user = SecurityUtil.getLoginUser();
String userId = user.getId(); String userId = user.getId();
String text = chatCompletionReqDto.getText(); String text = chatCompletionReqDto.getText();
// String userId = "123";
// String text = "hello, how should I call you?";
SseEmitter sseEmitter = new SseEmitter(-1L); SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调 // open 回调
listener.setOnOpen((s) -> { listener.setOnOpen((s) -> {
// 敏感词检查,计费 // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?)
log.info("callback: ws open event emmit"); log.info("callback: ws open event emmit");
}); });
// 对话完成的回调 // 对话完成的回调

@ -0,0 +1,91 @@
package cn.teammodel.service.impl;
import cn.hutool.core.lang.UUID;
import cn.teammodel.common.ErrorCode;
import cn.teammodel.common.PK;
import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.dao.ChatSessionRepository;
import cn.teammodel.model.dto.ai.UpdateSessionDto;
import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.model.entity.ai.ChatSession.Message;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatSessionService;
import cn.teammodel.utils.RepositoryUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.time.Instant;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
/**
* @author winter
* @create 2023-12-19 15:31
*/
@Service
@Slf4j
public class ChatSessionServiceImpl implements ChatSessionService {
@Resource
private ChatSessionRepository chatSessionRepository;
@Override
public void createSession() {
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
// 初始化欢迎语
Message message = Message.ofGptText("你好 " + user.getName() + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
List<Message> history = Collections.singletonList(message);
ChatSession chatSession = new ChatSession();
chatSession.setId(UUID.randomUUID().toString());
chatSession.setCode(PK.CHAT_SESSION);
chatSession.setTitle("新对话");
chatSession.setUserId(userId);
chatSession.setCreateTime(Instant.now().toEpochMilli());
chatSession.setUpdateTime(Instant.now().toEpochMilli());
chatSession.setHistory(history);
chatSessionRepository.save(chatSession);
}
@Override
public List<ChatSession> listMySession() {
String userId = SecurityUtil.getUserId();
List<ChatSession> sessions = chatSessionRepository.findByUserId(userId);
// 按更新时间排序
sessions = sessions.stream().sorted(Comparator.comparing(ChatSession::getUpdateTime)).collect(Collectors.toList());
return sessions;
}
@Override
public ChatSession updateSession(UpdateSessionDto updateSessionDto) {
String id = updateSessionDto.getId();
String title = updateSessionDto.getTitle();
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "");
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR);
}
CosmosPatchOperations options = CosmosPatchOperations.create()
.replace("/title", title);
chatSessionRepository.save(id, PK.of(PK.CHAT_SESSION),ChatSession.class, options);
return null;
}
@Override
public void deleteSession(String id) {
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "该会话不存在");
// 鉴权
if (!session.getUserId().equals(userId)) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR);
}
chatSessionRepository.deleteById(id, PK.of(PK.CHAT_SESSION));
}
}

@ -2,10 +2,7 @@ package cn.teammodel;
import cn.teammodel.common.PK; import cn.teammodel.common.PK;
import cn.teammodel.controller.admin.service.AdminAppraiseService; import cn.teammodel.controller.admin.service.AdminAppraiseService;
import cn.teammodel.dao.AppraiseRecordRepository; import cn.teammodel.dao.*;
import cn.teammodel.dao.AppraiseRepository;
import cn.teammodel.dao.SchoolRepository;
import cn.teammodel.dao.StudentRepository;
import cn.teammodel.manager.DingAlertNotifier; import cn.teammodel.manager.DingAlertNotifier;
import cn.teammodel.model.dto.admin.TimeRangeDto; import cn.teammodel.model.dto.admin.TimeRangeDto;
import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto; import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto;
@ -45,6 +42,8 @@ class TeamModelExtensionApplicationTests {
@Autowired @Autowired
SchoolRepository schoolRepository; SchoolRepository schoolRepository;
@Autowired @Autowired
ChatSessionRepository chatSessionRepository;
@Autowired
private AppraiseRepository appraiseRepository; private AppraiseRepository appraiseRepository;
@Test @Test
@ -217,4 +216,11 @@ class TeamModelExtensionApplicationTests {
// ruleDto.setUpdateRule(rule); // ruleDto.setUpdateRule(rule);
System.out.println(adminAppraiseService.updateAchieveRule(ruleDto)); System.out.println(adminAppraiseService.updateAchieveRule(ruleDto));
} }
@Test
public void testSelectChatSession() {
System.out.println(chatSessionRepository.findByUserId("1595321354"));
}
} }

Loading…
Cancel
Save