diff --git a/src/main/java/cn/teammodel/controller/frontend/AiController.java b/src/main/java/cn/teammodel/controller/frontend/AiController.java index db380ec..848ff97 100644 --- a/src/main/java/cn/teammodel/controller/frontend/AiController.java +++ b/src/main/java/cn/teammodel/controller/frontend/AiController.java @@ -3,13 +3,16 @@ package cn.teammodel.controller.frontend; import cn.teammodel.common.IdRequest; import cn.teammodel.common.R; import cn.teammodel.model.dto.ai.*; +import cn.teammodel.model.entity.TmdUserDetail; import cn.teammodel.model.entity.ai.ChatApp; import cn.teammodel.model.entity.ai.ChatSession; +import cn.teammodel.security.utils.SecurityUtil; import cn.teammodel.service.ChatAppService; import cn.teammodel.service.ChatMessageService; import cn.teammodel.service.ChatSessionService; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; +import org.apache.commons.lang3.StringUtils; import org.springframework.web.bind.annotation.*; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -33,7 +36,8 @@ public class AiController { @PostMapping("chat/completion") @ApiOperation("与 spark 的流式对话") public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) { - return chatMessageService.chatCompletion(chatCompletionReqDto); + String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject(); + return chatMessageService.chatCompletion(chatCompletionReqDto, userId); } // @PostMapping("chat/test/completion") @@ -73,33 +77,40 @@ public class AiController { @GetMapping("session/my") @ApiOperation("查询我的聊天会话") public R> listMySession() { - List sessions = chatSessionService.listMySession(); + String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject(); + List sessions = chatSessionService.listMySession(userId); return R.success(sessions); } @GetMapping("chat/history/{sessionId}") @ApiOperation("查询我的聊天记录") public R> getHistory(@PathVariable String sessionId) { - List history = chatSessionService.listMyHistory(sessionId); + String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject(); + List history = chatSessionService.listHistory(sessionId, userId); return R.success(history); } @PostMapping("session/create") @ApiOperation("创建聊天会话") public R createSession() { - String sessionId = chatSessionService.createSession(); + String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject(); + String name = (String) ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().get("name"); + name = StringUtils.isBlank(name) ? "老师" : name; + String sessionId = chatSessionService.createSession(userId, name); return R.success(sessionId); } @PostMapping("session/remove") @ApiOperation("删除聊天会话") public R removeSession(@RequestBody @Valid IdRequest idRequest) { - chatSessionService.deleteSession(idRequest.getId()); + String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject(); + chatSessionService.deleteSession(idRequest.getId(), userId); return R.success("删除会话成功"); } @PostMapping("session/update") @ApiOperation("更新聊天会话") public R updateSession(@RequestBody @Valid UpdateSessionDto updateSessionDto) { - ChatSession session = chatSessionService.updateSession(updateSessionDto); + String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject(); + ChatSession session = chatSessionService.updateSession(updateSessionDto, userId); return R.success(session); } diff --git a/src/main/java/cn/teammodel/model/entity/TmdUserDetail.java b/src/main/java/cn/teammodel/model/entity/TmdUserDetail.java index a9a4362..345db3c 100644 --- a/src/main/java/cn/teammodel/model/entity/TmdUserDetail.java +++ b/src/main/java/cn/teammodel/model/entity/TmdUserDetail.java @@ -1,5 +1,6 @@ package cn.teammodel.model.entity; +import io.jsonwebtoken.Claims; import lombok.Data; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.userdetails.UserDetails; @@ -14,6 +15,8 @@ import java.util.Collection; @Data public class TmdUserDetail implements UserDetails { private User user; + + private Claims claims; @Override public Collection getAuthorities() { return null; diff --git a/src/main/java/cn/teammodel/security/SecurityConfiguration.java b/src/main/java/cn/teammodel/security/SecurityConfiguration.java index 17114fa..7903910 100644 --- a/src/main/java/cn/teammodel/security/SecurityConfiguration.java +++ b/src/main/java/cn/teammodel/security/SecurityConfiguration.java @@ -31,10 +31,6 @@ public class SecurityConfiguration { private RestAccessDeniedHandler restAccessDeniedHandler; @Resource private RestAuthenticationEntryPoint restAuthenticationEntryPoint; - @Resource - private AuthInnerTokenFilter authInnerTokenFilter; - @Resource - private ApiAuthTokenFilter apiAuthTokenFilter; @Bean @Order(2) @@ -65,7 +61,7 @@ public class SecurityConfiguration { .anyRequest().authenticated() ) .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) // 启用 OIDC jwt filter - .addFilterAfter(authInnerTokenFilter, BearerTokenAuthenticationFilter.class) // 添加 x-auth-authToken filter + .addFilterAfter(new AuthInnerTokenFilter(), BearerTokenAuthenticationFilter.class) // 添加 x-auth-authToken filter .exceptionHandling() .authenticationEntryPoint(restAuthenticationEntryPoint) .accessDeniedHandler(restAccessDeniedHandler); @@ -74,13 +70,25 @@ public class SecurityConfiguration { @Bean @Order(1) public SecurityFilterChain outterApiFilterChain(HttpSecurity http) throws Exception { - http. - antMatcher("/ai/api/**") + http + // CSRF禁用,因为不使用session + .csrf().disable() + .cors().configurationSource(corsConfigurationSource()) + .and() + // 禁用HTTP响应标头 + .headers().cacheControl().disable() + .and() + .sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS) + .and() + .antMatcher("/ai/api/**") .authorizeRequests(authorizeRequests -> authorizeRequests .anyRequest().authenticated() ) - .addFilterAfter(apiAuthTokenFilter, BearerTokenAuthenticationFilter.class); + .addFilterAfter(new ApiAuthTokenFilter(), BearerTokenAuthenticationFilter.class) + .exceptionHandling() + .authenticationEntryPoint(restAuthenticationEntryPoint) + .accessDeniedHandler(restAccessDeniedHandler); return http.build(); } diff --git a/src/main/java/cn/teammodel/security/filter/ApiAuthTokenFilter.java b/src/main/java/cn/teammodel/security/filter/ApiAuthTokenFilter.java index 92636f1..e6ff9ae 100644 --- a/src/main/java/cn/teammodel/security/filter/ApiAuthTokenFilter.java +++ b/src/main/java/cn/teammodel/security/filter/ApiAuthTokenFilter.java @@ -1,17 +1,15 @@ package cn.teammodel.security.filter; +import cn.teammodel.model.entity.TmdUserDetail; import cn.teammodel.security.utils.JwtTokenUtil; -import io.jsonwebtoken.Claims; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.jetbrains.annotations.NotNull; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; -import javax.annotation.Resource; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -23,31 +21,30 @@ import java.io.IOException; * @author winter * @create 2023-11-09 10:43 */ -@Component @Slf4j public class ApiAuthTokenFilter extends OncePerRequestFilter { - @Resource - private JwtTokenUtil jwtTokenUtil; + JwtTokenUtil jwtTokenUtil = new JwtTokenUtil(); - // todo: 修改 context 的值 + 写一下多过滤器链的复盘 @Override protected void doFilterInternal(HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull FilterChain filterChain) throws ServletException, IOException { SecurityContext context = SecurityContextHolder.getContext(); + // 进入此过滤器说明 OIDC 认证成功,则验证 authToken // 验证 authToken 合法 String token = request.getHeader("token"); if (StringUtils.isBlank(token)) { filterChain.doFilter(request, response); return; } - Claims claims = jwtTokenUtil.validAndGetClaims(token, "fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA=", 315360000); - if (claims == null) { +// Claims claims = jwtTokenUtil.validAndGetClaims(token, "fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA=", 315360000); + TmdUserDetail validUserDetail1 = jwtTokenUtil.getOutterTokenDetail(request); + if (validUserDetail1 == null) { SecurityContextHolder.clearContext(); // 验证失败不应该在此处抛出异常,应该维护好 context 的值,以便整个过滤器链正常运行 filterChain.doFilter(request, response); return; } // 组装 authToken 的 jwt 进 authentication - UsernamePasswordAuthenticationToken finalAuthentication = new UsernamePasswordAuthenticationToken(claims, null, null); + UsernamePasswordAuthenticationToken finalAuthentication = new UsernamePasswordAuthenticationToken(validUserDetail1, null, null); context.setAuthentication(finalAuthentication); filterChain.doFilter(request, response); } diff --git a/src/main/java/cn/teammodel/security/filter/AuthInnerTokenFilter.java b/src/main/java/cn/teammodel/security/filter/AuthInnerTokenFilter.java index 169da02..0fcc04e 100644 --- a/src/main/java/cn/teammodel/security/filter/AuthInnerTokenFilter.java +++ b/src/main/java/cn/teammodel/security/filter/AuthInnerTokenFilter.java @@ -3,13 +3,11 @@ package cn.teammodel.security.filter; import cn.teammodel.model.entity.TmdUserDetail; import cn.teammodel.security.utils.JwtTokenUtil; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; import javax.servlet.FilterChain; @@ -24,12 +22,9 @@ import java.util.Collection; * @author winter * @create 2023-11-09 10:43 */ -@Component @Slf4j public class AuthInnerTokenFilter extends OncePerRequestFilter { - - @Autowired - JwtTokenUtil jwtTokenUtil; + JwtTokenUtil jwtTokenUtil = new JwtTokenUtil(); @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { diff --git a/src/main/java/cn/teammodel/security/utils/JwtTokenUtil.java b/src/main/java/cn/teammodel/security/utils/JwtTokenUtil.java index 902467c..2f067ae 100644 --- a/src/main/java/cn/teammodel/security/utils/JwtTokenUtil.java +++ b/src/main/java/cn/teammodel/security/utils/JwtTokenUtil.java @@ -1,15 +1,14 @@ package cn.teammodel.security.utils; -import cn.teammodel.model.entity.User; +import cn.hutool.extra.spring.SpringUtil; import cn.teammodel.model.entity.TmdUserDetail; +import cn.teammodel.model.entity.User; import io.jsonwebtoken.Claims; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.SignatureAlgorithm; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Value; import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.stereotype.Component; import javax.crypto.spec.SecretKeySpec; import javax.servlet.http.HttpServletRequest; @@ -22,13 +21,15 @@ import java.util.stream.Collectors; * @date 2022年11月24日 下午10:50 * @description 生成jwt令牌的工具类 */ -@Component @Slf4j public class JwtTokenUtil { - private static final long NEVER_EXPIRE = 315360000; // 没有永不过期的api: 让时钟偏移十年 + private static final long NEVER_EXPIRE = 315360000 ; // 没有永不过期的api: 让时钟偏移十年 - @Value("${jwt.secret}") - private String secret; + private final String secret; + + { + secret = SpringUtil.getProperty("jwt.secret"); + } /** * 生成token @@ -159,6 +160,22 @@ public class JwtTokenUtil { return tmdUserDetail; } + public TmdUserDetail getOutterTokenDetail(HttpServletRequest request) { + String token = request.getHeader("token"); + if (StringUtils.isBlank(token)) { + return null; + } + Claims claims = getClaimsFromToken(token); + if (claims == null) { + return null; + } + + // 组装 TmdUserDetail + TmdUserDetail tmdUserDetail = new TmdUserDetail(); + tmdUserDetail.setClaims(claims); + return tmdUserDetail; + } + @SuppressWarnings("unchecked") private Set convertToArray(Object o) { if (o == null) { diff --git a/src/main/java/cn/teammodel/security/utils/SecurityUtil.java b/src/main/java/cn/teammodel/security/utils/SecurityUtil.java index 448c60b..c61e3ca 100644 --- a/src/main/java/cn/teammodel/security/utils/SecurityUtil.java +++ b/src/main/java/cn/teammodel/security/utils/SecurityUtil.java @@ -49,8 +49,7 @@ public class SecurityUtil /** * 获取用户 **/ - public static User getLoginUser() - { + public static User getLoginUser() { try { return ((TmdUserDetail) getAuthentication().getPrincipal()).getUser(); diff --git a/src/main/java/cn/teammodel/service/ChatMessageService.java b/src/main/java/cn/teammodel/service/ChatMessageService.java index 6cc867e..8c68fab 100644 --- a/src/main/java/cn/teammodel/service/ChatMessageService.java +++ b/src/main/java/cn/teammodel/service/ChatMessageService.java @@ -11,5 +11,5 @@ public interface ChatMessageService { /** * AI 聊天 */ - SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto); + SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId); } diff --git a/src/main/java/cn/teammodel/service/ChatSessionService.java b/src/main/java/cn/teammodel/service/ChatSessionService.java index 79ccf37..b185875 100644 --- a/src/main/java/cn/teammodel/service/ChatSessionService.java +++ b/src/main/java/cn/teammodel/service/ChatSessionService.java @@ -11,13 +11,13 @@ import java.util.List; */ public interface ChatSessionService { - String createSession(); + String createSession(String userId, String name); - List listMySession(); + List listMySession(String userId); - ChatSession updateSession(UpdateSessionDto updateSessionDto); + ChatSession updateSession(UpdateSessionDto updateSessionDto, String userId); - void deleteSession(String id); + void deleteSession(String id, String userId); - List listMyHistory(String sessionId); + List listHistory(String sessionId, String userId); } diff --git a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java index e5faaff..c9f3f2e 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatMessageServiceImpl.java @@ -45,12 +45,12 @@ public class ChatMessageServiceImpl implements ChatMessageService { @Override - public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) { + public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) { // 目前仅使用讯飞星火大模型 String appId = chatCompletionReqDto.getAppId(); SseEmitter sseEmitter; if (StringUtils.isEmpty(appId)) { - sseEmitter = completionBySession(chatCompletionReqDto); + sseEmitter = completionBySession(chatCompletionReqDto, userId); } else { sseEmitter = completionByApp(chatCompletionReqDto, false); } @@ -114,10 +114,7 @@ public class ChatMessageServiceImpl implements ChatMessageService { /** * 会话模式 */ - private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto) { -// User user = SecurityUtil.getLoginUser(); -// String userId = user.getId(); - String userId = "1595321354"; + private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto, String userId) { String userPrompt = chatCompletionReqDto.getText(); String sessionId = chatCompletionReqDto.getSessionId(); diff --git a/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java index b5bfb30..cd37643 100644 --- a/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java +++ b/src/main/java/cn/teammodel/service/impl/ChatSessionServiceImpl.java @@ -33,12 +33,9 @@ public class ChatSessionServiceImpl implements ChatSessionService { private ChatSessionRepository chatSessionRepository; @Override - public String createSession() { -// todo User user = SecurityUtil.getLoginUser(); || 2. user.getName() -// String userId = user.getId(); - String userId = "1595321354"; + public String createSession(String userId, String name) { // 初始化欢迎语 - Message message = Message.of("", "你好" + "罗老师" + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); + Message message = Message.of("", "你好" + name + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); List history = Collections.singletonList(message); ChatSession chatSession = new ChatSession(); chatSession.setId(UUID.randomUUID().toString()); @@ -52,9 +49,8 @@ public class ChatSessionServiceImpl implements ChatSessionService { } @Override - public List listMySession() { -// String userId = SecurityUtil.getUserId(); - String userId = "1595321354"; + public List listMySession(String userId) { + List sessions = chatSessionRepository.findByUserId(userId); // 按更新时间排序 if (ObjectUtils.isNotEmpty(sessions)) { @@ -64,12 +60,9 @@ public class ChatSessionServiceImpl implements ChatSessionService { } @Override - public ChatSession updateSession(UpdateSessionDto updateSessionDto) { + public ChatSession updateSession(UpdateSessionDto updateSessionDto, String userId) { String id = updateSessionDto.getId(); String title = updateSessionDto.getTitle(); -// User user = SecurityUtil.getLoginUser(); -// String userId = user.getId(); - String userId = "1595321354"; ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), ""); if (!session.getUserId().equals(userId)) { @@ -82,10 +75,7 @@ public class ChatSessionServiceImpl implements ChatSessionService { } @Override - public void deleteSession(String id) { -// User user = SecurityUtil.getLoginUser(); -// String userId = user.getId(); - String userId = "1595321354"; + public void deleteSession(String id, String userId) { ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "该会话不存在"); // 鉴权 if (!session.getUserId().equals(userId)) { @@ -95,9 +85,7 @@ public class ChatSessionServiceImpl implements ChatSessionService { } @Override - public List listMyHistory(String sessionId) { -// User user = SecurityUtil.getLoginUser(); - String userId = "1595321354"; + public List listHistory(String sessionId, String userId) { ChatSession session = chatSessionRepository.findChatSessionByIdAndCode(sessionId, PK.CHAT_SESSION); if (!userId.equals(session.getUserId())) { throw new ServiceException(ErrorCode.NO_AUTH_ERROR);