diff --git a/src/main/java/cn/teammodel/ai/SparkGptClient.java b/src/main/java/cn/teammodel/ai/SparkGptClient.java index ff21814..4616e06 100644 --- a/src/main/java/cn/teammodel/ai/SparkGptClient.java +++ b/src/main/java/cn/teammodel/ai/SparkGptClient.java @@ -38,9 +38,9 @@ public class SparkGptClient implements InitializingBean { * 静态构造对象方法 */ public void init() { - String authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret()); + authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret()); this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://"); - log.info("鉴权 url: {}", this.authUrl); + log.info("[SPARK CHAT] 鉴权 url: {}", this.authUrl); this.okHttpClient = new OkHttpClient() .newBuilder() diff --git a/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java b/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java index 45aa1c8..1a12831 100644 --- a/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java +++ b/src/main/java/cn/teammodel/ai/domain/SparkChatRequestParam.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.Builder.Default; import lombok.Data; import java.util.List; @@ -22,16 +23,16 @@ public class SparkChatRequestParam { //每个用户的id,用于区分不同用户 private String uid; //指定访问的领域,general指向V1.5版本 generalv2指向V2版本。注意:不同的取值对应的url也不一样! - @Builder.Default + @Default private String domain = "generalv3"; //核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高 - @Builder.Default + @Default private Float temperature = 0.5F; //模型回答的tokens的最大长度 - @Builder.Default + @Default private Integer maxTokens = 2048; //从k个候选中随机选择⼀个(⾮等概率) - @Builder.Default + @Default private Integer top_k = 4; //用于关联用户会话 private String chatId; diff --git a/src/main/java/cn/teammodel/ai/listener/SparkGptStreamListener.java b/src/main/java/cn/teammodel/ai/listener/SparkGptStreamListener.java index 013deae..66937a0 100644 --- a/src/main/java/cn/teammodel/ai/listener/SparkGptStreamListener.java +++ b/src/main/java/cn/teammodel/ai/listener/SparkGptStreamListener.java @@ -20,7 +20,7 @@ import javax.validation.constraints.NotNull; import java.util.function.Consumer; /** - * okhttp 调用 ws 接口时的 listner + * okhttp 调用 ws 接口时的 listener * @author winter * @create 2023-12-15 16:17 */ @@ -44,7 +44,6 @@ public class SparkGptStreamListener extends WebSocketListener { try { onOpen.accept(webSocket); } catch (Exception e) { - // todo: 这儿不应该直接调她 this.onFailure(webSocket, e, response); } } @@ -76,7 +75,6 @@ public class SparkGptStreamListener extends WebSocketListener { onComplete.accept(answer); SseHelper.complete(sseEmitter); } - } @Override @@ -89,7 +87,7 @@ public class SparkGptStreamListener extends WebSocketListener { } - // 这几个 function 可以在 listener 被调用时设置, 实现类似事件的回调 + // 这几个 function 可以在 listener 被调用时设置, 实现类似事件的回调(可以使用模板方法模式实现) protected Consumer onOpen = (s) -> {}; protected Consumer onError = (s) -> {}; protected Consumer onComplete = (s) -> {}; diff --git a/src/main/java/cn/teammodel/common/PK.java b/src/main/java/cn/teammodel/common/PK.java index 69eec63..599750c 100644 --- a/src/main/java/cn/teammodel/common/PK.java +++ b/src/main/java/cn/teammodel/common/PK.java @@ -23,6 +23,7 @@ public interface PK { */ String STUDENT = "Base-%s"; String CLASS = "Class-%s"; + String CHAT_SESSION = "ChatSession"; /** * 构建分区键 diff --git a/src/main/java/cn/teammodel/controller/frontend/AiController.java b/src/main/java/cn/teammodel/controller/frontend/AiController.java index 45e05c6..d58a2fa 100644 --- a/src/main/java/cn/teammodel/controller/frontend/AiController.java +++ b/src/main/java/cn/teammodel/controller/frontend/AiController.java @@ -1,17 +1,25 @@ package cn.teammodel.controller.frontend; +import cn.teammodel.common.IdRequest; +import cn.teammodel.common.R; import cn.teammodel.model.dto.ai.ChatCompletionReqDto; +import cn.teammodel.model.dto.ai.UpdateSessionDto; +import cn.teammodel.model.entity.ai.ChatSession; import cn.teammodel.service.ChatMessageService; +import cn.teammodel.service.ChatSessionService; import io.swagger.annotations.ApiOperation; import org.springframework.web.bind.annotation.*; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import javax.validation.Valid; +import java.util.List; @RestController -@RequestMapping("/public/ai") +@RequestMapping("/ai") public class AiController { + @Resource + private ChatSessionService chatSessionService; @Resource private ChatMessageService chatMessageService; @@ -20,9 +28,34 @@ public class AiController { public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) { return chatMessageService.chatCompletion(chatCompletionReqDto); } - @GetMapping("test/completion") - @ApiOperation("与 spark 的流式对话") - public SseEmitter testChatCompletion() { - return chatMessageService.chatCompletion(null); + + @GetMapping("session/my") + @ApiOperation("查询我的聊天会话") + public R> listMySession() { + List sessions = chatSessionService.listMySession(); + return R.success(sessions); + } + @PostMapping("session/create") + @ApiOperation("创建聊天会话") + public R createSession() { + chatSessionService.createSession(); + return R.success("创建会话成功"); + } + + @PostMapping("session/remove") + @ApiOperation("删除聊天会话") + public R removeSession(@RequestBody @Valid IdRequest idRequest) { + chatSessionService.deleteSession(idRequest.getId()); + return R.success("删除会话成功"); } + @PostMapping("session/update") + @ApiOperation("更新聊天会话") + public R updateSession(@RequestBody @Valid UpdateSessionDto updateSessionDto) { + ChatSession session = chatSessionService.updateSession(updateSessionDto); + return R.success(session); + } + + + + } \ No newline at end of file diff --git a/src/main/java/cn/teammodel/dao/ChatSessionRepository.java b/src/main/java/cn/teammodel/dao/ChatSessionRepository.java new file mode 100644 index 0000000..013e4f7 --- /dev/null +++ b/src/main/java/cn/teammodel/dao/ChatSessionRepository.java @@ -0,0 +1,21 @@ +package cn.teammodel.dao; + +import cn.teammodel.model.entity.ai.ChatSession; +import com.azure.spring.data.cosmos.repository.CosmosRepository; +import com.azure.spring.data.cosmos.repository.Query; +import org.springframework.stereotype.Repository; + +import java.util.List; + +/** + * @author winter + * @create 2023-11-28 17:39 + */ +@Repository +public interface ChatSessionRepository extends CosmosRepository { + @Query("select c.id, c.code, c.title, c.userId, c.createTime from c where c.code = 'ChatSession' and c.sessionId = @sessionId") + List findBySessionId(String sessionId); + + @Query("select c.id, c.code, c.title, c.userId, c.createTime from c where c.code = 'ChatSession' and c.userId = @userId") + List findByUserId(String userId); +} diff --git a/src/main/java/cn/teammodel/model/dto/ai/UpdateSessionDto.java b/src/main/java/cn/teammodel/model/dto/ai/UpdateSessionDto.java new file mode 100644 index 0000000..e641ed1 --- /dev/null +++ b/src/main/java/cn/teammodel/model/dto/ai/UpdateSessionDto.java @@ -0,0 +1,18 @@ +package cn.teammodel.model.dto.ai; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import javax.validation.constraints.NotBlank; + +/** + * @author winter + * @create 2023-12-19 15:42 + */ +@Data +public class UpdateSessionDto { + @ApiModelProperty(value = "session id", required = true) + @NotBlank + private String id; + private String title; +} diff --git a/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java b/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java new file mode 100644 index 0000000..dfb5542 --- /dev/null +++ b/src/main/java/cn/teammodel/model/entity/ai/ChatSession.java @@ -0,0 +1,66 @@ +package cn.teammodel.model.entity.ai; + +import cn.hutool.core.lang.UUID; +import cn.teammodel.model.entity.BaseItem; +import com.azure.spring.data.cosmos.core.mapping.Container; +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.*; + +import java.time.Instant; +import java.util.List; + +/** + * 聊天会话,绑定 teacherId(userId), 主键id: sessionId + * @author winter + * @create 2023-12-19 15:09 + */ +@EqualsAndHashCode(callSuper = true) +@Container(containerName = "Teacher") +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ChatSession extends BaseItem { + /** + * 会话名称 + */ + private String title; + /** + * 用户 id + */ + private String userId; + private Long createTime; + /** + * 产生对话即更新时间,按更新时间排序 + */ + private Long updateTime; + private List history; + + @Data + public static class Message { + private String id; + private String userText; + private String gptText; + /** + * 消耗的 point + */ + private Integer cost; + private Long createTime; + + public static Message ofUserText(String userText) { + Message message = new Message(); + message.setId(UUID.randomUUID().toString()); + message.setCost(0); + message.setUserText(userText); + message.setCreateTime(Instant.now().toEpochMilli()); + return message; + } + + public static Message ofGptText(String gptText) { + Message message = new Message(); + message.setId(UUID.randomUUID().toString()); + message.setCost(0); + message.setGptText(gptText); + message.setCreateTime(Instant.now().toEpochMilli()); + return message; + } + } +} diff --git a/src/main/java/cn/teammodel/service/ChatSessionService.java b/src/main/java/cn/teammodel/service/ChatSessionService.java new file mode 100644 index 0000000..17ea615 --- /dev/null +++ b/src/main/java/cn/teammodel/service/ChatSessionService.java @@ -0,0 +1,21 @@ +package cn.teammodel.service; + +import cn.teammodel.model.dto.ai.UpdateSessionDto; +import cn.teammodel.model.entity.ai.ChatSession; + +import java.util.List; + +/** + * @author winter + * @create 2023-12-19 15:30 + */ +public interface ChatSessionService { + + void createSession(); + + List listMySession(); + + ChatSession updateSession(UpdateSessionDto updateSessionDto); + + void deleteSession(String id); +} diff --git a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java index d695fc7..db65933 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java @@ -32,14 +32,12 @@ public class ChatMessageServiceImpl implements ChatMessageService { User user = SecurityUtil.getLoginUser(); String userId = user.getId(); String text = chatCompletionReqDto.getText(); -// String userId = "123"; -// String text = "hello, how should I call you?"; SseEmitter sseEmitter = new SseEmitter(-1L); SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); // open 回调 listener.setOnOpen((s) -> { - // 敏感词检查,计费 + // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) log.info("callback: ws open event emmit"); }); // 对话完成的回调 diff --git a/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java new file mode 100644 index 0000000..2f272b7 --- /dev/null +++ b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java @@ -0,0 +1,91 @@ +package cn.teammodel.service.impl; + +import cn.hutool.core.lang.UUID; +import cn.teammodel.common.ErrorCode; +import cn.teammodel.common.PK; +import cn.teammodel.config.exception.ServiceException; +import cn.teammodel.dao.ChatSessionRepository; +import cn.teammodel.model.dto.ai.UpdateSessionDto; +import cn.teammodel.model.entity.User; +import cn.teammodel.model.entity.ai.ChatSession; +import cn.teammodel.model.entity.ai.ChatSession.Message; +import cn.teammodel.security.utils.SecurityUtil; +import cn.teammodel.service.ChatSessionService; +import cn.teammodel.utils.RepositoryUtil; +import com.azure.cosmos.models.CosmosPatchOperations; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import javax.annotation.Resource; +import java.time.Instant; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author winter + * @create 2023-12-19 15:31 + */ +@Service +@Slf4j +public class ChatSessionServiceImpl implements ChatSessionService { + @Resource + private ChatSessionRepository chatSessionRepository; + + @Override + public void createSession() { + User user = SecurityUtil.getLoginUser(); + String userId = user.getId(); + // 初始化欢迎语 + Message message = Message.ofGptText("你好 " + user.getName() + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); + List history = Collections.singletonList(message); + ChatSession chatSession = new ChatSession(); + chatSession.setId(UUID.randomUUID().toString()); + chatSession.setCode(PK.CHAT_SESSION); + chatSession.setTitle("新对话"); + chatSession.setUserId(userId); + chatSession.setCreateTime(Instant.now().toEpochMilli()); + chatSession.setUpdateTime(Instant.now().toEpochMilli()); + chatSession.setHistory(history); + chatSessionRepository.save(chatSession); + } + + @Override + public List listMySession() { + String userId = SecurityUtil.getUserId(); + List sessions = chatSessionRepository.findByUserId(userId); + // 按更新时间排序 + sessions = sessions.stream().sorted(Comparator.comparing(ChatSession::getUpdateTime)).collect(Collectors.toList()); + return sessions; + } + + @Override + public ChatSession updateSession(UpdateSessionDto updateSessionDto) { + String id = updateSessionDto.getId(); + String title = updateSessionDto.getTitle(); + User user = SecurityUtil.getLoginUser(); + String userId = user.getId(); + + ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), ""); + if (!session.getUserId().equals(userId)) { + throw new ServiceException(ErrorCode.NO_AUTH_ERROR); + } + CosmosPatchOperations options = CosmosPatchOperations.create() + .replace("/title", title); + chatSessionRepository.save(id, PK.of(PK.CHAT_SESSION),ChatSession.class, options); + return null; + } + + @Override + public void deleteSession(String id) { + User user = SecurityUtil.getLoginUser(); + String userId = user.getId(); + ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "该会话不存在"); + // 鉴权 + if (!session.getUserId().equals(userId)) { + throw new ServiceException(ErrorCode.NO_AUTH_ERROR); + } + chatSessionRepository.deleteById(id, PK.of(PK.CHAT_SESSION)); + } +} diff --git a/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java b/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java index 064994d..8d0dfb8 100644 --- a/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java +++ b/src/test/java/cn/teammodel/TeamModelExtensionApplicationTests.java @@ -2,10 +2,7 @@ package cn.teammodel; import cn.teammodel.common.PK; import cn.teammodel.controller.admin.service.AdminAppraiseService; -import cn.teammodel.dao.AppraiseRecordRepository; -import cn.teammodel.dao.AppraiseRepository; -import cn.teammodel.dao.SchoolRepository; -import cn.teammodel.dao.StudentRepository; +import cn.teammodel.dao.*; import cn.teammodel.manager.DingAlertNotifier; import cn.teammodel.model.dto.admin.TimeRangeDto; import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto; @@ -45,6 +42,8 @@ class TeamModelExtensionApplicationTests { @Autowired SchoolRepository schoolRepository; @Autowired + ChatSessionRepository chatSessionRepository; + @Autowired private AppraiseRepository appraiseRepository; @Test @@ -217,4 +216,11 @@ class TeamModelExtensionApplicationTests { // ruleDto.setUpdateRule(rule); System.out.println(adminAppraiseService.updateAchieveRule(ruleDto)); } + + @Test + public void testSelectChatSession() { + System.out.println(chatSessionRepository.findByUserId("1595321354")); + } + + }