diff --git a/src/main/java/cn/teammodel/common/PK.java b/src/main/java/cn/teammodel/common/PK.java index 80a6ff7..b0a1d04 100644 --- a/src/main/java/cn/teammodel/common/PK.java +++ b/src/main/java/cn/teammodel/common/PK.java @@ -26,6 +26,7 @@ public interface PK { String CHAT_SESSION = "ChatSession"; String WEEK_DUTY = "Duty"; String WEEK_DUTY_RECORD = "DutyRecord-%s"; + String CHAT_APP = "ChatApp"; /** * 构建分区键 diff --git a/src/main/java/cn/teammodel/controller/frontend/AiController.java b/src/main/java/cn/teammodel/controller/frontend/AiController.java index f3879bd..3bc8021 100644 --- a/src/main/java/cn/teammodel/controller/frontend/AiController.java +++ b/src/main/java/cn/teammodel/controller/frontend/AiController.java @@ -3,8 +3,12 @@ package cn.teammodel.controller.frontend; import cn.teammodel.common.IdRequest; import cn.teammodel.common.R; import cn.teammodel.model.dto.ai.ChatCompletionReqDto; +import cn.teammodel.model.dto.ai.CreateChatAppDto; +import cn.teammodel.model.dto.ai.UpdateChatAppDto; import cn.teammodel.model.dto.ai.UpdateSessionDto; +import cn.teammodel.model.entity.ai.ChatApp; import cn.teammodel.model.entity.ai.ChatSession; +import cn.teammodel.service.ChatAppService; import cn.teammodel.service.ChatMessageService; import cn.teammodel.service.ChatSessionService; import io.swagger.annotations.Api; @@ -26,6 +30,8 @@ public class AiController { private ChatSessionService chatSessionService; @Resource private ChatMessageService chatMessageService; + @Resource + private ChatAppService chatAppService; @PostMapping("chat/completion") @ApiOperation("与 spark 的流式对话") @@ -33,7 +39,7 @@ public class AiController { return chatMessageService.chatCompletion(chatCompletionReqDto); } - @PostMapping("chat/test/completion") +// @PostMapping("chat/test/completion") @ApiOperation("与 spark 的流式对话") public SseEmitter testCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) throws IOException, InterruptedException { SseEmitter sseEmitter = new SseEmitter(); @@ -100,6 +106,36 @@ public class AiController { return R.success(session); } + @GetMapping("app/list") + @ApiOperation("查询聊天应用列表") + public R> listApp() { + List chatApps = chatAppService.listApp(); + return R.success(chatApps); + } + + @PostMapping("app/create") + @ApiOperation("创建聊天应用") + public R createApp(@RequestBody @Valid CreateChatAppDto createChatAppDto) { + ChatApp chatApp = chatAppService.createApp(createChatAppDto); + return R.success(chatApp); + } + + @PostMapping("app/update") + @ApiOperation("更新聊天应用") + public R updateApp(@RequestBody @Valid UpdateChatAppDto updateChatAppDto) { + ChatApp chatApp = chatAppService.updateApp(updateChatAppDto); + return R.success(chatApp); + } + + @PostMapping("app/remove") + @ApiOperation("删除聊天应用") + public R updateApp(@RequestBody @Valid IdRequest idRequest) { + chatAppService.deleteApp(idRequest); + return R.success("删除应用成功"); + } + + + diff --git a/src/main/java/cn/teammodel/model/dto/Appraise/UpdateNodeDto.java b/src/main/java/cn/teammodel/model/dto/Appraise/UpdateNodeDto.java index 7b87987..03b2107 100644 --- a/src/main/java/cn/teammodel/model/dto/Appraise/UpdateNodeDto.java +++ b/src/main/java/cn/teammodel/model/dto/Appraise/UpdateNodeDto.java @@ -17,7 +17,6 @@ public class UpdateNodeDto { @ApiModelProperty(value = "评价项节点的 id") String id; String name; - String[] path; String logo; Integer order; boolean isPraise; diff --git a/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java b/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java index e4a0ff4..ac163da 100644 --- a/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java +++ b/src/main/java/cn/teammodel/model/dto/ai/ChatCompletionReqDto.java @@ -8,9 +8,9 @@ import javax.validation.constraints.NotBlank; public class ChatCompletionReqDto { private String sessionId; /** - * 预设的会话面具 + * 预设的会话面具 id */ - private Long appId; + private String appId; @NotBlank(message = "请输入消息内容") private String text; } \ No newline at end of file diff --git a/src/main/java/cn/teammodel/model/dto/ai/CreateChatAppDto.java b/src/main/java/cn/teammodel/model/dto/ai/CreateChatAppDto.java new file mode 100644 index 0000000..98e7241 --- /dev/null +++ b/src/main/java/cn/teammodel/model/dto/ai/CreateChatAppDto.java @@ -0,0 +1,26 @@ +package cn.teammodel.model.dto.ai; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import javax.validation.constraints.NotBlank; + +@Data +public class CreateChatAppDto { + @ApiModelProperty("应用图标") + private String icon; + + @ApiModelProperty("应用名称") + @NotBlank(message = "请输入应用名称") + private String name; + + @ApiModelProperty("应用描述") + private String description; + + @ApiModelProperty("是否公开") + private boolean publicApp = false; + + @NotBlank(message = "请输入应用提示词") + @ApiModelProperty("应用提示词") + private String prompt; +} \ No newline at end of file diff --git a/src/main/java/cn/teammodel/model/dto/ai/UpdateChatAppDto.java b/src/main/java/cn/teammodel/model/dto/ai/UpdateChatAppDto.java new file mode 100644 index 0000000..c69e70f --- /dev/null +++ b/src/main/java/cn/teammodel/model/dto/ai/UpdateChatAppDto.java @@ -0,0 +1,30 @@ +package cn.teammodel.model.dto.ai; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import javax.validation.constraints.NotBlank; + +@Data +public class UpdateChatAppDto { + @ApiModelProperty("应用 id") + @NotBlank(message = "请输入应用 id") + private String id; + + @ApiModelProperty("应用图标") + private String icon; + + @ApiModelProperty("应用名称") + @NotBlank(message = "请输入应用名称") + private String name; + + @ApiModelProperty("应用描述") + private String description; + + @ApiModelProperty("是否公开") + private boolean publicApp = false; + + @NotBlank(message = "请输入应用提示词") + @ApiModelProperty("应用提示词") + private String prompt; +} \ No newline at end of file diff --git a/src/main/java/cn/teammodel/model/entity/ai/ChatApp.java b/src/main/java/cn/teammodel/model/entity/ai/ChatApp.java new file mode 100644 index 0000000..1f728d9 --- /dev/null +++ b/src/main/java/cn/teammodel/model/entity/ai/ChatApp.java @@ -0,0 +1,37 @@ +package cn.teammodel.model.entity.ai; + +import cn.teammodel.model.entity.BaseItem; +import com.azure.spring.data.cosmos.core.mapping.Container; +import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.Data; +import lombok.EqualsAndHashCode; + +/** + * 聊天应用(面具) + * code: ChatApp + * @author winter + * @create 2024-01-23 11:08 + */ +@EqualsAndHashCode(callSuper = true) +@Container(containerName = "School") +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ChatApp extends BaseItem { + /** + * 租户 id ( public: 不允许修改 / telnet id) + */ + private String schoolId; + /** + * 图标 + */ + private String icon; + private String name; + private String description; + /** + * 面具提示词 + */ + private String prompt; + private String creator; + private String creatorId; + private Long createTime; +} diff --git a/src/main/java/cn/teammodel/repository/ChatAppRepository.java b/src/main/java/cn/teammodel/repository/ChatAppRepository.java new file mode 100644 index 0000000..3f73df4 --- /dev/null +++ b/src/main/java/cn/teammodel/repository/ChatAppRepository.java @@ -0,0 +1,21 @@ +package cn.teammodel.repository; + +import cn.teammodel.model.entity.ai.ChatApp; +import com.azure.spring.data.cosmos.repository.CosmosRepository; +import com.azure.spring.data.cosmos.repository.Query; +import org.springframework.stereotype.Repository; + +import java.util.List; + +/** + * @author winter + * @create 2023-11-28 17:39 + */ +@Repository +public interface ChatAppRepository extends CosmosRepository { + @Query("SELECT * FROM c WHERE c.code = 'ChatApp' AND (c.schoolId = @schoolId OR c.schoolId = 'public'))") + List findAllByCodeAndSchoolId(String schoolId); + + @Query("SELECT c.id, c.schoolId, c.prompt FROM c WHERE c.code = 'ChatApp' AND c.id = @id") + List findByAppId(String id); +} diff --git a/src/main/java/cn/teammodel/service/ChatAppService.java b/src/main/java/cn/teammodel/service/ChatAppService.java new file mode 100644 index 0000000..7180a94 --- /dev/null +++ b/src/main/java/cn/teammodel/service/ChatAppService.java @@ -0,0 +1,22 @@ +package cn.teammodel.service; + +import cn.teammodel.common.IdRequest; +import cn.teammodel.model.dto.ai.CreateChatAppDto; +import cn.teammodel.model.dto.ai.UpdateChatAppDto; +import cn.teammodel.model.entity.ai.ChatApp; + +import java.util.List; + +/** + * @author winter + * @create 2024-01-23 11:19 + */ +public interface ChatAppService { + ChatApp createApp(CreateChatAppDto createChatAppDto); + + ChatApp updateApp(UpdateChatAppDto updateChatAppDto); + + void deleteApp(IdRequest idRequest); + + List listApp(); +} diff --git a/src/main/java/cn/teammodel/service/impl/ChatAppServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatAppServiceImpl.java new file mode 100644 index 0000000..b5024a0 --- /dev/null +++ b/src/main/java/cn/teammodel/service/impl/ChatAppServiceImpl.java @@ -0,0 +1,85 @@ +package cn.teammodel.service.impl; + +import cn.teammodel.common.ErrorCode; +import cn.teammodel.common.IdRequest; +import cn.teammodel.common.PK; +import cn.teammodel.config.exception.ServiceException; +import cn.teammodel.model.dto.ai.CreateChatAppDto; +import cn.teammodel.model.dto.ai.UpdateChatAppDto; +import cn.teammodel.model.entity.User; +import cn.teammodel.model.entity.ai.ChatApp; +import cn.teammodel.repository.ChatAppRepository; +import cn.teammodel.security.utils.SecurityUtil; +import cn.teammodel.service.ChatAppService; +import cn.teammodel.utils.RepositoryUtil; +import org.springframework.stereotype.Service; + +import javax.annotation.Resource; +import java.time.Instant; +import java.util.List; + +/** + * @author winter + * @create 2024-01-23 11:19 + */ +@Service +public class ChatAppServiceImpl implements ChatAppService { + @Resource + private ChatAppRepository chatAppRepostitory; + + @Override + public ChatApp createApp(CreateChatAppDto createChatAppDto) { + User user = SecurityUtil.getLoginUser(); + ChatApp newApp = new ChatApp(); + newApp.setName(createChatAppDto.getName()); + newApp.setIcon(createChatAppDto.getIcon()); + if (createChatAppDto.isPublicApp()) { + newApp.setSchoolId("public"); + } else { + newApp.setSchoolId(user.getSchoolId()); + } + newApp.setDescription(createChatAppDto.getDescription()); + newApp.setPrompt(createChatAppDto.getPrompt()); + newApp.setCreator(user.getName()); + newApp.setCreatorId(user.getId()); + newApp.setCreateTime(Instant.now().toEpochMilli()); + newApp.setCode(PK.CHAT_APP); + + return chatAppRepostitory.save(newApp); + } + + @Override + public ChatApp updateApp(UpdateChatAppDto updateChatAppDto) { + String id = updateChatAppDto.getId(); + String userId = SecurityUtil.getUserId(); + ChatApp chatApp = RepositoryUtil.findOne(chatAppRepostitory.findByAppId(id), "该应用不存在"); + if (!userId.equals(chatApp.getCreatorId())) { + throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "您没有权限修改该应用"); + } + chatApp.setIcon(updateChatAppDto.getIcon()); + chatApp.setName(updateChatAppDto.getName()); + chatApp.setDescription(updateChatAppDto.getDescription()); + chatApp.setPrompt(updateChatAppDto.getPrompt()); + return chatAppRepostitory.save(chatApp); + } + + @Override + public void deleteApp(IdRequest idRequest) { + String appId = idRequest.getId(); + String userId = SecurityUtil.getUserId(); + ChatApp chatApp = RepositoryUtil.findOne(chatAppRepostitory.findByAppId(appId), "该应用不存在"); + if (userId.equals(chatApp.getCreatorId())) { + chatAppRepostitory.deleteById(appId); + } else { + throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "您没有权限修改该应用"); + } + } + + @Override + public List listApp() { + User user = SecurityUtil.getLoginUser(); + String schoolId = user.getSchoolId(); + List apps = chatAppRepostitory.findAllByCodeAndSchoolId(schoolId); + return apps; + } +} diff --git a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java index dcf6c18..19a7f34 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java @@ -8,6 +8,8 @@ import cn.teammodel.ai.listener.SparkGptStreamListener; import cn.teammodel.common.ErrorCode; import cn.teammodel.common.PK; import cn.teammodel.config.exception.ServiceException; +import cn.teammodel.model.entity.ai.ChatApp; +import cn.teammodel.repository.ChatAppRepository; import cn.teammodel.repository.ChatSessionRepository; import cn.teammodel.model.dto.ai.ChatCompletionReqDto; import cn.teammodel.model.entity.User; @@ -18,6 +20,7 @@ import cn.teammodel.utils.RepositoryUtil; import com.azure.cosmos.models.CosmosPatchOperations; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ObjectUtils; +import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -37,11 +40,81 @@ public class ChatMessageServiceImpl implements ChatMessageService { private SparkGptClient sparkGptClient; @Resource private ChatSessionRepository chatSessionRepository; + @Resource + private ChatAppRepository chatAppRepository; @Override public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) { // 目前仅使用讯飞星火大模型 + String appId = chatCompletionReqDto.getAppId(); + SseEmitter sseEmitter; + if (StringUtils.isEmpty(appId)) { + sseEmitter = completionBySession(chatCompletionReqDto); + } else { + sseEmitter = completionByApp(chatCompletionReqDto, false); + } + return sseEmitter; + } + + /** + * 面具模式(暂时不存储聊天记录) + */ + private SseEmitter completionByApp(ChatCompletionReqDto chatCompletionReqDto, boolean justApi) { + String appId = chatCompletionReqDto.getAppId(); + String userPrompt = chatCompletionReqDto.getText(); + User user = SecurityUtil.getLoginUser(); + String userId = user.getId(); + String schoolId = user.getSchoolId(); + + // 查询 appId 获取 prompt + // 通过 prompt 和 userprompt 生成结果 + // 直接返回 + ChatApp chatApp = RepositoryUtil.findOne(chatAppRepository.findByAppId(appId), "该应用不存在"); + // 检验 app 是否可以被该用户使用 + if (!schoolId.equals(chatApp.getSchoolId()) && !"public".equals(chatApp.getSchoolId())) { + throw new ServiceException(ErrorCode.NO_AUTH_ERROR.getCode(), "无权使用该应用"); + } + + String appPrompt = chatApp.getPrompt(); + SseEmitter sseEmitter = new SseEmitter(-1L); + SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter); + // open 回调 + listener.setOnOpen((s) -> { + // 敏感词检查,计费 (设计模型, reducePoints, 或者都可以在完成的回调中做?) + log.info("callback: ws open event emmit"); + }); + // 对话完成的回调 + listener.setOnComplete((s) -> { + log.info("callback: ws complete event emmit"); + SseHelper.send(sseEmitter, "[DONE]"); + // 处理完成后的事件: + if (!justApi) { + // 保存消息记录, 缓存更改 + } + }); + // 错误的回调 + listener.setOnError((s) -> { + log.error("callback: ws error, info: " + s); + // 返还积分 + }); + List messageList = new ArrayList<>(); + messageList.add(SparkChatRequestParam.Message.ofAssistant(appPrompt)); + messageList.add(SparkChatRequestParam.Message.ofUser(userPrompt)); + SparkChatRequestParam requestParam = SparkChatRequestParam + .builder() + .uid(userId) + .chatId(appId) + .messageList(messageList) + .build(); + sparkGptClient.streamChatCompletion(requestParam, listener); + return sseEmitter; + } + + /** + * 会话模式 + */ + private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto) { User user = SecurityUtil.getLoginUser(); String userId = user.getId(); String userPrompt = chatCompletionReqDto.getText();