feat: 初步走通 spark Gpt 的对接流程,留下扩展其他 AI 的空间

11111
winter 1 year ago
parent 0bae71207b
commit fee64cb1ac

@ -37,6 +37,16 @@
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-oauth2-resource-server</artifactId> <artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>3.14.9</version>
</dependency>
<!-- cosmos --> <!-- cosmos -->
<!-- <dependency>--> <!-- <dependency>-->

@ -0,0 +1,115 @@
package cn.teammodel.ai;
import cn.hutool.json.JSONUtil;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
* spark Gpt client
* @author winter
* @create 2023-12-15 14:29
*/
@Component
@Data
@Slf4j
public class SparkGptClient implements InitializingBean {
@Resource
private SparkGptProperties sparkGptProperties;
private OkHttpClient okHttpClient;
private String authUrl;
/**
*
*/
public void init() {
String authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret());
this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://");
log.info("鉴权 url: {}", this.authUrl);
this.okHttpClient = new OkHttpClient()
.newBuilder()
.connectTimeout(90, TimeUnit.SECONDS)
.readTimeout(90, TimeUnit.SECONDS)
.writeTimeout(90, TimeUnit.SECONDS)
.build();
}
/**
* , sse
*/
public void streamChatCompletion(SparkChatRequestParam param, SparkGptStreamListener listener) {
try {
param.setAppId(sparkGptProperties.getAppId());
Request request = new Request.Builder().url(authUrl).build();
// 设置请求参数
listener.setRequestJson(param.toJsonParams());
log.info("请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty());
okHttpClient.newWebSocket(request, listener);
} catch (Exception e) {
log.error("Spark AI 请求异常: {}", e.getMessage());
e.printStackTrace();
}
}
/**
* URL
*/
public static String genAuthUrl(String endpoint, String apiKey, String apiSecret) {
URL url = null;
String date = null;
String preStr = null;
Mac mac = null;
try {
url = new URL(endpoint);
// 时间
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
format.setTimeZone(TimeZone.getTimeZone("GMT"));
date = format.format(new Date());
// 拼接
preStr = "host: " + url.getHost() + "\n" +
"date: " + date + "\n" +
"GET " + url.getPath() + " HTTP/1.1";
// SHA256加密
mac = Mac.getInstance("hmacsha256");
SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
mac.init(spec);
} catch (Exception e) {
log.error("生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret);
throw new RuntimeException(e);
}
byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
// Base64加密
String sha = Base64.getEncoder().encodeToString(hexDigits);
// 拼接
String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
// 拼接地址
HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().
addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).
addQueryParameter("date", date).//
addQueryParameter("host", url.getHost()).//
build();
return httpUrl.toString();
}
@Override
public void afterPropertiesSet() throws Exception {
init();
}
}

@ -0,0 +1,20 @@
package cn.teammodel.ai;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
* @author winter
* @create 2023-12-15 14:29
*/
@Data
@Configuration
@ConfigurationProperties(prefix = "spark.gpt")
public class SparkGptProperties {
private String endpoint;
private String appId;
private String apiKey;
private String apiSecret;
}

@ -0,0 +1,30 @@
package cn.teammodel.ai;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@UtilityClass
@Slf4j
public class SseHelper {
/**
* server client sse
*/
public void complete(SseEmitter sseEmitter) {
try {
sseEmitter.complete();
} catch (Exception e) {
}
}
/**
* client
*/
public void send(SseEmitter sseEmitter, Object data) {
try {
sseEmitter.send(data);
} catch (Exception e) {
log.error("sseEmitter send error", e);
}
}
}

@ -0,0 +1,83 @@
package cn.teammodel.ai.domain;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import java.util.List;
/**
* endpoint
* @author winter
* @create 2023-12-15 16:04
*/
@Data
@Builder
// 注意这两个注解一起使用会让无参构造丢失
public class SparkChatRequestParam {
//应用appid从开放平台控制台创建的应用中获取
private String appId;
//每个用户的id用于区分不同用户
private String uid;
//指定访问的领域,general指向V1.5版本 generalv2指向V2版本。注意不同的取值对应的url也不一样
@Builder.Default
private String domain = "generalv3";
//核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高
@Builder.Default
private Float temperature = 0.5F;
//模型回答的tokens的最大长度
@Builder.Default
private Integer maxTokens = 2048;
//从k个候选中随机选择⼀个⾮等概率
@Builder.Default
private Integer top_k = 4;
//用于关联用户会话
private String chatId;
private List<Message> messageList;
@Data
@AllArgsConstructor
public static class Message {
private String role;
private String content;
/**
* ,使,
*/
public static Message ofUser(String content){
return new Message("user",content);
}
public static Message ofAssistant(String content){
return new Message("assistant",content);
}
}
public String toJsonParams(){
ObjectMapper om = new ObjectMapper();
ObjectNode root = om.createObjectNode();
ObjectNode header = om.createObjectNode();
header.put("app_id",appId);
header.put("uid",uid);
ObjectNode parameter = om.createObjectNode();
ObjectNode chat = om.createObjectNode();
chat.put("domain", domain);
chat.put("temperature", temperature);
chat.put("max_tokens", maxTokens);
chat.put("top_k", top_k);
chat.put("chat_id", chatId);
parameter.set("chat", chat);
ObjectNode payload = om.createObjectNode();
payload.set("message", om.createObjectNode().putPOJO("text", messageList));
root.set("header", header);
root.set("parameter", parameter);
root.set("payload", payload);
return root.toString();
}
}

@ -0,0 +1,77 @@
package cn.teammodel.ai.domain;
import lombok.Data;
import java.util.List;
/**
* spark ws
* @Author: winter
*/
@Data
public class SparkChatResponse {
private Header header;
private Payload payload;
@Data
public static class Header{
/**
* 00
*/
private Integer code;
private String message;
private String sid;
/**
* [0,1,2]012
*/
private Integer status;
}
@Data
public static class Payload{
private Choices choices;
private Usage usage;
}
@Data
public static class Choices{
private Integer status;
private Integer seq;
private List<Text> text;
}
@Data
public static class Usage{
private UsageText text;
}
@Data
public static class UsageText{
private Integer question_tokens;
/**
* tokens
*/
private Integer prompt_tokens;
/**
* tokens
*/
private Integer completion_tokens;
/**
* prompt_tokenscompletion_tokenstokens
*/
private Integer total_tokens;
}
@Data
public static class Text{
/**
* AI
*/
private String content;
/**
* assistantAI
*/
private String role;
private Integer index;
}
}

@ -0,0 +1,96 @@
package cn.teammodel.ai.listener;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.domain.SparkChatResponse;
import cn.teammodel.common.ErrorCode;
import cn.teammodel.config.exception.ServiceException;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Nullable;
import javax.validation.constraints.NotNull;
import java.util.function.Consumer;
/**
* okhttp ws listner
* @author winter
* @create 2023-12-15 16:17
*/
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@RequiredArgsConstructor
public class SparkGptStreamListener extends WebSocketListener {
private final SseEmitter sseEmitter;
private String requestJson;
/**
*
*/
private String answer = "";
@Override
public void onOpen(WebSocket webSocket, @NotNull Response response) {
// ws 建立连接后发送请求的数据, 在 onMessage 中接受 server 的响应数据
webSocket.send(requestJson);
// 执行成功回调
try {
onOpen.accept(webSocket);
} catch (Exception e) {
// todo: 这儿不应该直接调她
this.onFailure(webSocket, e, response);
}
}
@Override
public void onMessage(WebSocket webSocket, String text) {
// 向 sse 推送消息(一次 onMessage 事件推送一次)
SparkChatResponse sparkChatResponse = JSONUtil.toBean(text, SparkChatResponse.class);
// 请求是否异常
if (sparkChatResponse.getHeader().getCode() != 0) {
this.onFailure(
webSocket,
new ServiceException(ErrorCode.SYSTEM_ERROR.getCode(),
sparkChatResponse.getHeader().getMessage()),
null
);
}
// 推送回答 segment
String msgSegment = sparkChatResponse.getPayload().getChoices().getText().get(0).getContent();
// 处理消息格式(空格和换行符)
msgSegment = StrUtil.replace(msgSegment, " ", "&#32;").replaceAll("\n", "&#92n");
answer += msgSegment;
SseHelper.send(sseEmitter, msgSegment);
// 处理模型的最终响应
if (sparkChatResponse.getHeader().getStatus() == 2) {
// 其实 spark 会主动断开连接
webSocket.close(1000, "done");
onComplete.accept(answer);
SseHelper.complete(sseEmitter);
}
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, @Nullable Response response) {
webSocket.close(1000, t.getMessage());
this.onError.accept(t);
// 失败时结束 sse 连接
SseHelper.send(sseEmitter,t.getMessage() + "[DONE]");
SseHelper.complete(sseEmitter);
}
// 这几个 function 可以在 listener 被调用时设置, 实现类似事件的回调
protected Consumer<WebSocket> onOpen = (s) -> {};
protected Consumer<Throwable> onError = (s) -> {};
protected Consumer<String> onComplete = (s) -> {};
}

@ -6,4 +6,5 @@ package cn.teammodel.common;
*/ */
public interface CommonConstant { public interface CommonConstant {
String DASH = "-"; String DASH = "-";
} }

@ -0,0 +1,13 @@
package cn.teammodel.common;
/**
* @author winter
* @create 2023-12-14 12:19
*/
public interface FiveEducations {
String MORAL = "美德";
String INTELLECTUAL = "美智";
String PHYSICAL = "体育";
String AESTHETIC = "美艺";
String LABOUR = "劳动";
}

@ -0,0 +1,32 @@
package cn.teammodel.controller.admin.controller;
import cn.teammodel.common.R;
import cn.teammodel.controller.admin.service.AdminAppraiseService;
import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto;
import cn.teammodel.model.entity.appraise.AchievementRule;
import io.swagger.annotations.ApiOperation;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.util.List;
/**
* @author winter
* @create 2023-12-13 16:07
*/
@RestController
@RequestMapping("admin/appraise")
public class AdminAppraiseController {
@Resource
private AdminAppraiseService adminAppraiseService;
@PostMapping("updateAchieveRule")
@ApiOperation("更新的 rule 节点将会直接覆盖老节点")
public R<List<AchievementRule>> updateAchieveRule(@RequestBody UpdateAchievementRuleDto ruleDto) {
List<AchievementRule> res = adminAppraiseService.updateAchieveRule(ruleDto);
return R.success(res);
}
}

@ -1,6 +1,8 @@
package cn.teammodel.controller.admin.service; package cn.teammodel.controller.admin.service;
import cn.teammodel.model.dto.admin.TimeRangeDto; import cn.teammodel.model.dto.admin.TimeRangeDto;
import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto;
import cn.teammodel.model.entity.appraise.AchievementRule;
import cn.teammodel.model.vo.admin.IndexData; import cn.teammodel.model.vo.admin.IndexData;
import cn.teammodel.model.vo.admin.RankPo; import cn.teammodel.model.vo.admin.RankPo;
import cn.teammodel.model.vo.admin.RankVo; import cn.teammodel.model.vo.admin.RankVo;
@ -28,4 +30,6 @@ public interface AdminAppraiseService {
List<RankPo> appraiseNodeRank(TimeRangeDto timeRangeDto); List<RankPo> appraiseNodeRank(TimeRangeDto timeRangeDto);
List<StudentRankVo> studentRank(TimeRangeDto timeRangeDto); List<StudentRankVo> studentRank(TimeRangeDto timeRangeDto);
List<AchievementRule> updateAchieveRule(UpdateAchievementRuleDto ruleDto);
} }

@ -7,7 +7,10 @@ import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.controller.admin.service.AdminAppraiseService; import cn.teammodel.controller.admin.service.AdminAppraiseService;
import cn.teammodel.dao.*; import cn.teammodel.dao.*;
import cn.teammodel.model.dto.admin.TimeRangeDto; import cn.teammodel.model.dto.admin.TimeRangeDto;
import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto;
import cn.teammodel.model.entity.User; import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.appraise.AchievementRule;
import cn.teammodel.model.entity.appraise.Appraise;
import cn.teammodel.model.entity.school.ClassInfo; import cn.teammodel.model.entity.school.ClassInfo;
import cn.teammodel.model.entity.school.School; import cn.teammodel.model.entity.school.School;
import cn.teammodel.model.entity.school.Student; import cn.teammodel.model.entity.school.Student;
@ -18,10 +21,14 @@ import cn.teammodel.model.vo.admin.RankVo;
import cn.teammodel.model.vo.admin.StudentRankVo; import cn.teammodel.model.vo.admin.StudentRankVo;
import cn.teammodel.model.vo.appraise.RecordVo; import cn.teammodel.model.vo.appraise.RecordVo;
import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.utils.RepositoryUtil;
import cn.teammodel.utils.SchoolDateUtil; import cn.teammodel.utils.SchoolDateUtil;
import com.azure.cosmos.models.CosmosPatchOperations;
import com.azure.cosmos.models.PartitionKey;
import com.azure.spring.data.cosmos.core.query.CosmosPageRequest; import com.azure.spring.data.cosmos.core.query.CosmosPageRequest;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.data.domain.Slice; import org.springframework.data.domain.Slice;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -46,10 +53,9 @@ public class AdminAppraiseServiceImpl implements AdminAppraiseService {
@Resource @Resource
private TeacherRepository teacherRepository; private TeacherRepository teacherRepository;
@Resource @Resource
private StudentRepository studentRepository;
@Resource
private AppraiseRepository appraiseRepository; private AppraiseRepository appraiseRepository;
@Resource
private StudentRepository studentRepository;
@Resource @Resource
private AppraiseRecordRepository appraiseRecordRepository; private AppraiseRecordRepository appraiseRecordRepository;
@ -130,7 +136,7 @@ public class AdminAppraiseServiceImpl implements AdminAppraiseService {
String academicYearId = timeRangeDto.getAcademicYearId(); String academicYearId = timeRangeDto.getAcademicYearId();
String schoolId = SecurityUtil.getLoginUser().getSchoolId(); String schoolId = SecurityUtil.getLoginUser().getSchoolId();
// fixme: 是否对时间范围做一些限制 // fixme: 是否对时间范围做一些限制(不能确保当前周有数据)
// 无参默认当前周 // 无参默认当前周
if (startTime == null || endTime == null) { if (startTime == null || endTime == null) {
// 将时间范围调整为当前周的周一到当前时间 // 将时间范围调整为当前周的周一到当前时间
@ -150,6 +156,10 @@ public class AdminAppraiseServiceImpl implements AdminAppraiseService {
endTime endTime
); );
if (res != null) {
res = res.stream().sorted((o1, o2) -> o2.getCreateTime().compareTo(o1.getCreateTime())).collect(Collectors.toList());
}
return res; return res;
} }
@ -167,6 +177,7 @@ public class AdminAppraiseServiceImpl implements AdminAppraiseService {
startTime, startTime,
endTime endTime
); );
if (ObjectUtils.isEmpty(rankPoList)) return null; if (ObjectUtils.isEmpty(rankPoList)) return null;
Set<String> classIdSet = rankPoList.stream().map(RankPo::getId).collect(Collectors.toSet()); Set<String> classIdSet = rankPoList.stream().map(RankPo::getId).collect(Collectors.toSet());
// 注意: 如果查询 in 的查询集在数据库中不存在,则在结果集也不会为 null. // 注意: 如果查询 in 的查询集在数据库中不存在,则在结果集也不会为 null.
@ -307,4 +318,46 @@ public class AdminAppraiseServiceImpl implements AdminAppraiseService {
return res; return res;
} }
@Override
public List<AchievementRule> updateAchieveRule(UpdateAchievementRuleDto ruleDto) {
String periodId = ruleDto.getPeriodId();
UpdateAchievementRuleDto.UpdateRule updateRule = ruleDto.getUpdateRule();
// fixme: 判断 参数
if (ObjectUtils.isEmpty(updateRule) || StringUtils.isBlank(updateRule.getId())) {
throw new ServiceException(ErrorCode.PARAMS_ERROR.getCode(), "rule id 不能为空");
}
User user = SecurityUtil.getLoginUser();
String schoolId = user.getSchoolId();
// String schoolId = "template";
Appraise appraise = RepositoryUtil.findOne(appraiseRepository.findRulesById(schoolId, periodId), "参数错误,找不到该学段下的评价规则");
List<AchievementRule> rules = appraise.getAchievementRules();
if (ObjectUtils.isEmpty(rules)) {
throw new ServiceException(ErrorCode.OPERATION_ERROR.getCode(), "该学段暂无没有成就规则");
}
// sort rules by level
rules = rules.stream().sorted(Comparator.comparing(AchievementRule::getLevel)).collect(Collectors.toList());
boolean flag = false;
int lastPromotionCount = 0;
for (int i = 0, rulesSize = rules.size(); i < rulesSize; i++) {
AchievementRule rule = rules.get(i);
if (updateRule.getId().equals(rule.getId())) {
BeanUtils.copyProperties(updateRule, rule);
lastPromotionCount = rule.getLevelCount() * rule.getPromotionLevel();
rule.setPromotionCount(lastPromotionCount);
flag = true;
continue;
}
// 处理后面的节点,将 promotionCount 依次修改
if (flag) {
lastPromotionCount = lastPromotionCount + rule.getLevelCount() * rule.getPromotionLevel();
rule.setPromotionCount(lastPromotionCount);
}
}
CosmosPatchOperations operations = CosmosPatchOperations.create().replace("/achievementRules", rules);
Appraise saved = appraiseRepository.save(appraise.getId(), new PartitionKey(PK.PK_APPRAISE), Appraise.class, operations);
return saved.getAchievementRules();
}
} }

@ -0,0 +1,28 @@
package cn.teammodel.controller.frontend;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.service.ChatMessageService;
import io.swagger.annotations.ApiOperation;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import javax.validation.Valid;
@RestController
@RequestMapping("/public/ai")
public class AiController {
@Resource
private ChatMessageService chatMessageService;
@PostMapping("chat/completion")
@ApiOperation("与 spark 的流式对话")
public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) {
return chatMessageService.chatCompletion(chatCompletionReqDto);
}
@GetMapping("test/completion")
@ApiOperation("与 spark 的流式对话")
public SseEmitter testChatCompletion() {
return chatMessageService.chatCompletion(null);
}
}

@ -1,9 +1,11 @@
package cn.teammodel.controller; package cn.teammodel.controller.frontend;
import cn.teammodel.common.IdRequest;
import cn.teammodel.common.R; import cn.teammodel.common.R;
import cn.teammodel.model.dto.Appraise.*; import cn.teammodel.model.dto.Appraise.*;
import cn.teammodel.model.entity.appraise.Appraise; import cn.teammodel.model.entity.appraise.Appraise;
import cn.teammodel.model.vo.appraise.AppraiseRecordVo; import cn.teammodel.model.vo.appraise.AppraiseRecordVo;
import cn.teammodel.model.vo.appraise.StudentReportVo;
import cn.teammodel.service.EvaluationService; import cn.teammodel.service.EvaluationService;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
@ -71,5 +73,13 @@ public class AppraiseController {
return R.success(res); return R.success(res);
} }
@PostMapping("studentReport")
@ApiOperation(value = "查看学生当前的学期的实时评价报告")
public R<StudentReportVo> studentReport(@Valid @RequestBody IdRequest idRequest) {
StudentReportVo res = evaluationService.studentReport(idRequest);
return R.success(res);
}
} }

@ -1,32 +1,32 @@
package cn.teammodel.controller; package cn.teammodel.controller.frontend;
import cn.teammodel.common.R; import cn.teammodel.common.R;
import cn.teammodel.dao.AppraiseRepository; import cn.teammodel.dao.AppraiseRepository;
import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource; import javax.annotation.Resource;
@RestController @RestController
@RequestMapping("/") @RequestMapping("/")
public class HelloController { public class HelloController {
@Resource @Resource
private AppraiseRepository appraiseRepository; private AppraiseRepository appraiseRepository;
@GetMapping("hello") @GetMapping("hello")
@PreAuthorize("@ss.hasRole('admin')") @PreAuthorize("@ss.hasRole('admin')")
public R<String> hello() { public R<String> hello() {
System.out.println(SecurityContextHolder.getContext().getAuthentication()); System.out.println(SecurityContextHolder.getContext().getAuthentication());
return new R(200, "success","hello world"); return new R(200, "success","hello world");
} }
@GetMapping("public/free") @GetMapping("public/free")
@PreAuthorize("permitAll()") @PreAuthorize("permitAll()")
public R<String> free() { public R<String> free() {
return new R(200, "success","hello world"); return new R(200, "success","hello world");
} }
} }

@ -111,6 +111,7 @@ public interface AppraiseRecordRepository extends CosmosRepository<AppraiseRecor
"group by c.targetId, n.appraiseNode.isPraise") "group by c.targetId, n.appraiseNode.isPraise")
List<RankPo> studentRank(String code, String academicYearId, Long startTime, Long endTime); List<RankPo> studentRank(String code, String academicYearId, Long startTime, Long endTime);
// test script
@Query("select * from Student as c where c.code = @code") @Query("select * from Student as c where c.code = @code")
List<AppraiseRecord> findByCode(String code); List<AppraiseRecord> findByCode(String code);
} }

@ -33,7 +33,8 @@ public interface AppraiseRepository extends CosmosRepository<Appraise, String> {
List<Appraise> findTemplateTree(); List<Appraise> findTemplateTree();
@Query("SELECT value n FROM School AS s join n in s.nodes where s.code = @code and n.id = @nodeId") @Query("SELECT value n FROM School AS s join n in s.nodes where s.code = @code and n.id = @nodeId")
List<AppraiseTreeNode> findNodeById(@Param("code") String code, @Param("nodeId") String nodeId); List<AppraiseTreeNode> findNodeById(@Param("code") String code, @Param("nodeId") String nodeId);
@Query("select n.id, n.name from School as c join n in c.nodes where c.code = @code and n.id in (@ids)") @Query("select n.id, n.name from School as c join n in c.nodes where c.code = @code and n.id in (@ids)")
List<AppraiseTreeNode> findAllByCodeAndIdIn(String code, Set<String> ids); List<AppraiseTreeNode> findAllByCodeAndIdIn(String code, Set<String> ids);
@Query("SELECT c.id, c.achievementRules FROM School AS c where c.code = 'Appraise' and c.schoolId = @schoolId and c.periodId = @periodId")
List<Appraise> findRulesById(String schoolId, String periodId);
} }

@ -11,6 +11,8 @@ import java.util.Set;
@Repository @Repository
public interface StudentRepository extends CosmosRepository<Student, String> { public interface StudentRepository extends CosmosRepository<Student, String> {
@Deprecated
// 似乎会出现 irs 的重复问题,建议使用下面的 findByIdAndCode
Student findStudentByIdAndCode(String id, String code); Student findStudentByIdAndCode(String id, String code);
@Query("select c.pk, c.code, c.id, c.name, c.gender, c.schoolId, c.periodId, c.year, c.createTime, c.picture, c.mail, c.mobile, c.country, c.classId, c.no, c.groupId, c.groupName, c.guardians, c.irs, c.salt from Student as c where c.id = @id and c.code = @code") @Query("select c.pk, c.code, c.id, c.name, c.gender, c.schoolId, c.periodId, c.year, c.createTime, c.picture, c.mail, c.mobile, c.country, c.classId, c.no, c.groupId, c.groupName, c.guardians, c.irs, c.salt from Student as c where c.id = @id and c.code = @code")

@ -0,0 +1,45 @@
package cn.teammodel.model.dto.admin;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.NotNull;
/**
* @author winter
* @create 2023-12-13 17:34
*/
@Data
public class UpdateAchievementRuleDto {
@NotNull
@ApiModelProperty("学段 id")
private String periodId;
@ApiModelProperty("更新的 rule 节点: 将会直接覆盖老节点")
private UpdateRule updateRule;
@Data
public static class UpdateRule {
@NotNull
private String id;
/**
*
*/
@NotNull
private String name;
/**
* logo
*/
@NotNull
private String logo;
/**
*
*/
@NotNull
private Integer levelCount;
/**
*
*/
@NotNull
private Integer promotionLevel;
}
}

@ -0,0 +1,16 @@
package cn.teammodel.model.dto.ai;
import lombok.Data;
import javax.validation.constraints.NotBlank;
@Data
public class ChatCompletionReqDto {
private Long sessionId;
/**
*
*/
private Long appId;
@NotBlank(message = "请输入消息内容")
private String text;
}

@ -0,0 +1,38 @@
package cn.teammodel.model.entity.appraise;
import lombok.Data;
/**
*
* @author winter
* @create 2023-12-13 15:23
*/
@Data
public class AchievementRule {
private String id;
/**
*
*/
private String name;
/**
* logo
*/
private String logo;
/**
*
*/
private Integer level;
/**
*
*/
private Integer levelCount;
/**
*
*/
private Integer promotionLevel;
/**
*
*/
private Integer promotionCount;
}

@ -26,7 +26,12 @@ public class Appraise extends BaseItem {
* id ( default, template) * id ( default, template)
*/ */
private String periodId; private String periodId;
/**
*
*/
private List<AppraiseTreeNode> nodes; private List<AppraiseTreeNode> nodes;
/**
*
*/
private List<AchievementRule> achievementRules;
} }

@ -0,0 +1,16 @@
package cn.teammodel.model.entity.school;
import lombok.Data;
/**
*
* @author winter
* @create 2023-12-13 11:44
*/
@Data
public class SchoolConfig {
/**
*
*/
private String scoreName = "醍摩豆";
}

@ -0,0 +1,20 @@
package cn.teammodel.model.vo.appraise;
import cn.teammodel.model.entity.appraise.AchievementRule;
import lombok.Data;
import java.util.Map;
/**
* vo
* @author winter
* @create 2023-12-04 15:26
*/
@Data
public class StudentReportVo {
private Integer praiseCount;
private Integer score;
private Map<String, Integer> praiseDistribution;
private Map<String, Integer> criticalDistribution;
private AchievementRule curAchievement;
}

@ -0,0 +1,15 @@
package cn.teammodel.service;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* @author winter
* @create 2023-12-18 15:20
*/
public interface ChatMessageService {
/**
* AI
*/
SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto);
}

@ -1,9 +1,11 @@
package cn.teammodel.service; package cn.teammodel.service;
import cn.teammodel.common.IdRequest;
import cn.teammodel.model.dto.Appraise.*; import cn.teammodel.model.dto.Appraise.*;
import cn.teammodel.model.entity.appraise.Appraise; import cn.teammodel.model.entity.appraise.Appraise;
import cn.teammodel.model.entity.appraise.AppraiseTreeNode; import cn.teammodel.model.entity.appraise.AppraiseTreeNode;
import cn.teammodel.model.vo.appraise.AppraiseRecordVo; import cn.teammodel.model.vo.appraise.AppraiseRecordVo;
import cn.teammodel.model.vo.appraise.StudentReportVo;
import java.util.List; import java.util.List;
@ -52,4 +54,9 @@ public interface EvaluationService {
List<AppraiseRecordVo> findVoteRecord(FindVoteRecordDto findVoteRecordDto); List<AppraiseRecordVo> findVoteRecord(FindVoteRecordDto findVoteRecordDto);
void recallVote(RecallVoteDto recallVoteDto); void recallVote(RecallVoteDto recallVoteDto);
/**
*
*/
StudentReportVo studentReport(IdRequest idRequest);
} }

@ -0,0 +1,69 @@
package cn.teammodel.service.impl;
import cn.teammodel.ai.SparkGptClient;
import cn.teammodel.ai.SseHelper;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import cn.teammodel.model.dto.ai.ChatCompletionReqDto;
import cn.teammodel.model.entity.User;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatMessageService;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.util.List;
/**
* @author winter
* @create 2023-12-18 15:20
*/
@Service
@Slf4j
public class ChatMessageServiceImpl implements ChatMessageService {
@Resource
private SparkGptClient sparkGptClient;
@Override
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) {
// 目前仅使用讯飞星火大模型
User user = SecurityUtil.getLoginUser();
String userId = user.getId();
String text = chatCompletionReqDto.getText();
// String userId = "123";
// String text = "hello, how should I call you?";
SseEmitter sseEmitter = new SseEmitter(-1L);
SparkGptStreamListener listener = new SparkGptStreamListener(sseEmitter);
// open 回调
listener.setOnOpen((s) -> {
// 敏感词检查,计费
log.info("callback: ws open event emmit");
});
// 对话完成的回调
listener.setOnComplete((s) -> {
log.info("callback: ws complete event emmit");
SseHelper.send(sseEmitter, "[DONE]");
// 处理完成后的事件: 保存消息记录
});
// 错误的回调
listener.setOnError((s) -> {
log.error("callback: ws error" );
// 返还积分
});
// todo: 拉取对话上下文
List<SparkChatRequestParam.Message> messageList = Lists.newArrayList();
messageList.add(SparkChatRequestParam.Message.ofUser(text));
// todo: sessionId
SparkChatRequestParam requestParam = SparkChatRequestParam
.builder()
.uid(userId)
.chatId("123")
.messageList(messageList)
.build();
sparkGptClient.streamChatCompletion(requestParam, listener);
return sseEmitter;
}
}

@ -2,19 +2,18 @@ package cn.teammodel.service.impl;
import cn.hutool.core.lang.UUID; import cn.hutool.core.lang.UUID;
import cn.teammodel.common.ErrorCode; import cn.teammodel.common.ErrorCode;
import cn.teammodel.common.IdRequest;
import cn.teammodel.common.PK; import cn.teammodel.common.PK;
import cn.teammodel.config.exception.ServiceException; import cn.teammodel.config.exception.ServiceException;
import cn.teammodel.dao.*; import cn.teammodel.dao.*;
import cn.teammodel.model.dto.Appraise.*; import cn.teammodel.model.dto.Appraise.*;
import cn.teammodel.model.entity.User; import cn.teammodel.model.entity.User;
import cn.teammodel.model.entity.appraise.Appraise; import cn.teammodel.model.entity.appraise.*;
import cn.teammodel.model.entity.appraise.AppraiseRecord;
import cn.teammodel.model.entity.appraise.AppraiseRecordItem;
import cn.teammodel.model.entity.appraise.AppraiseTreeNode;
import cn.teammodel.model.entity.school.ClassInfo; import cn.teammodel.model.entity.school.ClassInfo;
import cn.teammodel.model.entity.school.School; import cn.teammodel.model.entity.school.School;
import cn.teammodel.model.entity.school.Student; import cn.teammodel.model.entity.school.Student;
import cn.teammodel.model.vo.appraise.AppraiseRecordVo; import cn.teammodel.model.vo.appraise.AppraiseRecordVo;
import cn.teammodel.model.vo.appraise.StudentReportVo;
import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.EvaluationService; import cn.teammodel.service.EvaluationService;
import cn.teammodel.utils.RepositoryUtil; import cn.teammodel.utils.RepositoryUtil;
@ -82,7 +81,6 @@ public class EvaluationServiceImpl implements EvaluationService {
User loginUser = SecurityUtil.getLoginUser(); User loginUser = SecurityUtil.getLoginUser();
String schoolId = loginUser.getSchoolId(); String schoolId = loginUser.getSchoolId();
Appraise appraise = appraiseRepository.findAppraiseBySchoolIdAndPeriodIdAndCode(schoolId, periodId, PK.PK_APPRAISE); Appraise appraise = appraiseRepository.findAppraiseBySchoolIdAndPeriodIdAndCode(schoolId, periodId, PK.PK_APPRAISE);
// todo: 是否要对学段进行鉴权
if (appraise != null) { if (appraise != null) {
return this.buildTree(appraise); return this.buildTree(appraise);
} }
@ -91,7 +89,8 @@ public class EvaluationServiceImpl implements EvaluationService {
if (appraise == null) { if (appraise == null) {
throw new ServiceException(); throw new ServiceException();
} }
refreshAppraiseTree(appraise.getNodes()); // refresh
refreshAppraiseTree(appraise);
appraise.setPeriodId(periodId); appraise.setPeriodId(periodId);
appraise.setSchoolId(schoolId); appraise.setSchoolId(schoolId);
appraise.setId(null); appraise.setId(null);
@ -99,10 +98,13 @@ public class EvaluationServiceImpl implements EvaluationService {
return this.buildTree(appraise); return this.buildTree(appraise);
} }
/** /**
* id () * appraise ()
*/ */
private void refreshAppraiseTree(List<AppraiseTreeNode> nodes) { private void refreshAppraiseTree(Appraise appraise) {
List<AppraiseTreeNode> nodes = appraise.getNodes();
List<AchievementRule> rules = appraise.getAchievementRules();
List<AppraiseTreeNode> children = nodes.stream().filter(item -> item.getPid() != null).collect(Collectors.toList()); List<AppraiseTreeNode> children = nodes.stream().filter(item -> item.getPid() != null).collect(Collectors.toList());
// 将非 root 的 nodes 通过 pid 收集成 list, 再遍历 nodes, 将其每一个 id 对应 map 中的 list 的 pid修改成新的 id 即可 // 将非 root 的 nodes 通过 pid 收集成 list, 再遍历 nodes, 将其每一个 id 对应 map 中的 list 的 pid修改成新的 id 即可
Map<String, List<AppraiseTreeNode>> pidNodeMap = children.stream().collect(Collectors.groupingBy(AppraiseTreeNode::getPid)); Map<String, List<AppraiseTreeNode>> pidNodeMap = children.stream().collect(Collectors.groupingBy(AppraiseTreeNode::getPid));
@ -118,6 +120,8 @@ public class EvaluationServiceImpl implements EvaluationService {
item.setCreator("template"); item.setCreator("template");
}); });
// refresh rules
rules.forEach(item -> item.setId(UUID.randomUUID().toString()));
} }
@Override @Override
@ -328,9 +332,10 @@ public class EvaluationServiceImpl implements EvaluationService {
} else { } else {
CosmosPatchOperations operations = CosmosPatchOperations.create(); CosmosPatchOperations operations = CosmosPatchOperations.create();
operations.add("/nodes/0", item); operations.add("/nodes/0", item);
// 表扬 // 表扬 (待改进不会减少表扬数)
long praise = appraiseTreeNode.isPraise() ? 1 : -1; if (appraiseTreeNode.isPraise()) {
operations.increment("/praiseCount", praise); operations.increment("/praiseCount", 1);
}
// 加分 // 加分
int scoreToPlus = ObjectUtils.isEmpty(appraiseTreeNode.getScore()) ? 0 : appraiseTreeNode.getScore(); int scoreToPlus = ObjectUtils.isEmpty(appraiseTreeNode.getScore()) ? 0 : appraiseTreeNode.getScore();
operations.increment("/score", scoreToPlus); operations.increment("/score", scoreToPlus);
@ -401,6 +406,67 @@ public class EvaluationServiceImpl implements EvaluationService {
appraiseRecordRepository.save(appraiseRecord); appraiseRecordRepository.save(appraiseRecord);
} }
@Override
public StudentReportVo studentReport(IdRequest idRequest) {
String studentId = idRequest.getId();
User user = SecurityUtil.getLoginUser();
String schoolId = user.getSchoolId();
// 查询学生信息,学段等
Student student = RepositoryUtil.findOne(studentRepository.findByIdAndCode(studentId, String.format(PK.STUDENT, schoolId)), "当前学生不存在");
String periodId = student.getPeriodId();
String classId = student.getClassId();
List<School.Semester> semesters = schoolRepository.findSemestersById(schoolId, periodId);
// 生成学年 id
String academicYearId = SchoolDateUtil.calculateAcademicYearId(semesters, LocalDate.now());
// 查询评价文档
AppraiseRecord appraiseRecord = appraiseRecordRepository.findAppraiseRecordByTargetIdAndClassIdAndAcademicYearIdAndCode(
studentId,
classId,
academicYearId,
String.format(PK.PK_APPRAISE_RECORD, schoolId)
);
if (appraiseRecord == null || appraiseRecord.getNodes() == null) {
return null;
}
List<AppraiseRecordItem> records = appraiseRecord.getNodes();
// 查询成就规则
Appraise appraise = RepositoryUtil.findOne(appraiseRepository.findRulesById(schoolId, periodId), "当前成就规则还未创建");
List<AchievementRule> rules = appraise.getAchievementRules();
StudentReportVo reportVo = new StudentReportVo();
// 计算雷达图
Map<String, Integer> praiseDistribution = new HashMap<>();
Map<String, Integer> criticalDistribution = new HashMap<>();
for (AppraiseRecordItem record : records) {
AppraiseTreeNode appraiseNode = record.getAppraiseNode();
String[] path = appraiseNode.getPath();
String root = path[0];
if (appraiseNode.isPraise()) {
praiseDistribution.put(root, praiseDistribution.getOrDefault(root, 0) + 1);
} else {
criticalDistribution.put(root, criticalDistribution.getOrDefault(root, 0) + 1);
}
}
// 计算成就项 (排序
rules = rules.stream().sorted(Comparator.comparing(AchievementRule::getLevel).reversed()).collect(Collectors.toList());
Integer praiseCount = appraiseRecord.getPraiseCount();
AchievementRule curAchievement = rules.get(0);
for (AchievementRule rule : rules) {
Integer promotionCount = rule.getPromotionCount();
int flag = praiseCount / promotionCount;
// 说明当前规则成就匹配
if (flag >= 1) {
curAchievement = rule;
break;
}
}
reportVo.setPraiseCount(appraiseRecord.getPraiseCount());
reportVo.setScore(appraiseRecord.getScore());
reportVo.setPraiseDistribution(praiseDistribution);
reportVo.setCriticalDistribution(criticalDistribution);
reportVo.setCurAchievement(curAchievement);
return reportVo;
}
/** /**
* id id () * id id ()
*/ */

@ -17,8 +17,6 @@ spring:
key: JTUVk92Gjsx17L0xqxn0X4wX2thDPMKiw4daeTyV1HzPb6JmBeHdtFY1MF1jdctW1ofgzqkDMFOtcqS46by31A== key: JTUVk92Gjsx17L0xqxn0X4wX2thDPMKiw4daeTyV1HzPb6JmBeHdtFY1MF1jdctW1ofgzqkDMFOtcqS46by31A==
populate-query-metrics: true populate-query-metrics: true
security: security:
oauth2: oauth2:
resourceserver: resourceserver:
@ -26,6 +24,14 @@ spring:
issuer-uri: https://login.partner.microsoftonline.cn/4807e9cf-87b8-4174-aa5b-e76497d7392b/v2.0 issuer-uri: https://login.partner.microsoftonline.cn/4807e9cf-87b8-4174-aa5b-e76497d7392b/v2.0
audiences: 72643704-b2e7-4b26-b881-bd5865e7a7a5 audiences: 72643704-b2e7-4b26-b881-bd5865e7a7a5
spark:
gpt:
endpoint: https://spark-api.xf-yun.com/v3.1/chat
appId: c49d1e24
apiKey: 6c586e7dd1721ed1bb19bdb573b4ad34
apiSecret: MDU1MTU1Nzg4MDg2ZTJjZWU3MmI4ZGU1
jwt: jwt:
secret: fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA= secret: fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA=

@ -8,10 +8,8 @@ import cn.teammodel.dao.SchoolRepository;
import cn.teammodel.dao.StudentRepository; import cn.teammodel.dao.StudentRepository;
import cn.teammodel.manager.DingAlertNotifier; import cn.teammodel.manager.DingAlertNotifier;
import cn.teammodel.model.dto.admin.TimeRangeDto; import cn.teammodel.model.dto.admin.TimeRangeDto;
import cn.teammodel.model.entity.appraise.Appraise; import cn.teammodel.model.dto.admin.UpdateAchievementRuleDto;
import cn.teammodel.model.entity.appraise.AppraiseRecord; import cn.teammodel.model.entity.appraise.*;
import cn.teammodel.model.entity.appraise.AppraiseRecordItem;
import cn.teammodel.model.entity.appraise.AppraiseTreeNode;
import cn.teammodel.model.entity.school.School; import cn.teammodel.model.entity.school.School;
import cn.teammodel.model.vo.admin.RankPo; import cn.teammodel.model.vo.admin.RankPo;
import cn.teammodel.service.EvaluationService; import cn.teammodel.service.EvaluationService;
@ -207,8 +205,16 @@ class TeamModelExtensionApplicationTests {
} }
} }
// @Test @Test
// public void batchUpdateTimeFormat() { public void batchAppraiseSelect() {
// UpdateAchievementRuleDto ruleDto = new UpdateAchievementRuleDto();
// } ruleDto.setPeriodId("template");
AchievementRule rule = new AchievementRule();
rule.setId("1");
rule.setLevelCount(10);
rule.setPromotionLevel(5);
rule.setLogo("https://www.baidu.com");
// ruleDto.setUpdateRule(rule);
System.out.println(adminAppraiseService.updateAchieveRule(ruleDto));
}
} }

Loading…
Cancel
Save