fix: 修复 api

11111
winter 9 months ago
parent acf448e980
commit a53efd1f7a

@ -3,13 +3,16 @@ package cn.teammodel.controller.frontend;
import cn.teammodel.common.IdRequest; import cn.teammodel.common.IdRequest;
import cn.teammodel.common.R; import cn.teammodel.common.R;
import cn.teammodel.model.dto.ai.*; import cn.teammodel.model.dto.ai.*;
import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.model.entity.ai.ChatApp; import cn.teammodel.model.entity.ai.ChatApp;
import cn.teammodel.model.entity.ai.ChatSession; import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatAppService; import cn.teammodel.service.ChatAppService;
import cn.teammodel.service.ChatMessageService; import cn.teammodel.service.ChatMessageService;
import cn.teammodel.service.ChatSessionService; import cn.teammodel.service.ChatSessionService;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@ -33,7 +36,8 @@ public class AiController {
@PostMapping("chat/completion") @PostMapping("chat/completion")
@ApiOperation("与 spark 的流式对话") @ApiOperation("与 spark 的流式对话")
public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) { public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) {
return chatMessageService.chatCompletion(chatCompletionReqDto); String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
return chatMessageService.chatCompletion(chatCompletionReqDto, userId);
} }
// @PostMapping("chat/test/completion") // @PostMapping("chat/test/completion")
@ -73,33 +77,40 @@ public class AiController {
@GetMapping("session/my") @GetMapping("session/my")
@ApiOperation("查询我的聊天会话") @ApiOperation("查询我的聊天会话")
public R<List<ChatSession>> listMySession() { public R<List<ChatSession>> listMySession() {
List<ChatSession> sessions = chatSessionService.listMySession(); String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
List<ChatSession> sessions = chatSessionService.listMySession(userId);
return R.success(sessions); return R.success(sessions);
} }
@GetMapping("chat/history/{sessionId}") @GetMapping("chat/history/{sessionId}")
@ApiOperation("查询我的聊天记录") @ApiOperation("查询我的聊天记录")
public R<List<ChatSession.Message>> getHistory(@PathVariable String sessionId) { public R<List<ChatSession.Message>> getHistory(@PathVariable String sessionId) {
List<ChatSession.Message> history = chatSessionService.listMyHistory(sessionId); String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
List<ChatSession.Message> history = chatSessionService.listHistory(sessionId, userId);
return R.success(history); return R.success(history);
} }
@PostMapping("session/create") @PostMapping("session/create")
@ApiOperation("创建聊天会话") @ApiOperation("创建聊天会话")
public R<String> createSession() { public R<String> createSession() {
String sessionId = chatSessionService.createSession(); String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
String name = (String) ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().get("name");
name = StringUtils.isBlank(name) ? "老师" : name;
String sessionId = chatSessionService.createSession(userId, name);
return R.success(sessionId); return R.success(sessionId);
} }
@PostMapping("session/remove") @PostMapping("session/remove")
@ApiOperation("删除聊天会话") @ApiOperation("删除聊天会话")
public R<String> removeSession(@RequestBody @Valid IdRequest idRequest) { public R<String> removeSession(@RequestBody @Valid IdRequest idRequest) {
chatSessionService.deleteSession(idRequest.getId()); String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
chatSessionService.deleteSession(idRequest.getId(), userId);
return R.success("删除会话成功"); return R.success("删除会话成功");
} }
@PostMapping("session/update") @PostMapping("session/update")
@ApiOperation("更新聊天会话") @ApiOperation("更新聊天会话")
public R<ChatSession> updateSession(@RequestBody @Valid UpdateSessionDto updateSessionDto) { public R<ChatSession> updateSession(@RequestBody @Valid UpdateSessionDto updateSessionDto) {
ChatSession session = chatSessionService.updateSession(updateSessionDto); String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
ChatSession session = chatSessionService.updateSession(updateSessionDto, userId);
return R.success(session); return R.success(session);
} }

@ -1,5 +1,6 @@
package cn.teammodel.model.entity; package cn.teammodel.model.entity;
import io.jsonwebtoken.Claims;
import lombok.Data; import lombok.Data;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
@ -14,6 +15,8 @@ import java.util.Collection;
@Data @Data
public class TmdUserDetail implements UserDetails { public class TmdUserDetail implements UserDetails {
private User user; private User user;
private Claims claims;
@Override @Override
public Collection<? extends GrantedAuthority> getAuthorities() { public Collection<? extends GrantedAuthority> getAuthorities() {
return null; return null;

@ -31,10 +31,6 @@ public class SecurityConfiguration {
private RestAccessDeniedHandler restAccessDeniedHandler; private RestAccessDeniedHandler restAccessDeniedHandler;
@Resource @Resource
private RestAuthenticationEntryPoint restAuthenticationEntryPoint; private RestAuthenticationEntryPoint restAuthenticationEntryPoint;
@Resource
private AuthInnerTokenFilter authInnerTokenFilter;
@Resource
private ApiAuthTokenFilter apiAuthTokenFilter;
@Bean @Bean
@Order(2) @Order(2)
@ -65,7 +61,7 @@ public class SecurityConfiguration {
.anyRequest().authenticated() .anyRequest().authenticated()
) )
.oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) // 启用 OIDC jwt filter .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) // 启用 OIDC jwt filter
.addFilterAfter(authInnerTokenFilter, BearerTokenAuthenticationFilter.class) // 添加 x-auth-authToken filter .addFilterAfter(new AuthInnerTokenFilter(), BearerTokenAuthenticationFilter.class) // 添加 x-auth-authToken filter
.exceptionHandling() .exceptionHandling()
.authenticationEntryPoint(restAuthenticationEntryPoint) .authenticationEntryPoint(restAuthenticationEntryPoint)
.accessDeniedHandler(restAccessDeniedHandler); .accessDeniedHandler(restAccessDeniedHandler);
@ -74,13 +70,25 @@ public class SecurityConfiguration {
@Bean @Bean
@Order(1) @Order(1)
public SecurityFilterChain outterApiFilterChain(HttpSecurity http) throws Exception { public SecurityFilterChain outterApiFilterChain(HttpSecurity http) throws Exception {
http. http
antMatcher("/ai/api/**") // CSRF禁用因为不使用session
.csrf().disable()
.cors().configurationSource(corsConfigurationSource())
.and()
// 禁用HTTP响应标头
.headers().cacheControl().disable()
.and()
.sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS)
.and()
.antMatcher("/ai/api/**")
.authorizeRequests(authorizeRequests -> .authorizeRequests(authorizeRequests ->
authorizeRequests authorizeRequests
.anyRequest().authenticated() .anyRequest().authenticated()
) )
.addFilterAfter(apiAuthTokenFilter, BearerTokenAuthenticationFilter.class); .addFilterAfter(new ApiAuthTokenFilter(), BearerTokenAuthenticationFilter.class)
.exceptionHandling()
.authenticationEntryPoint(restAuthenticationEntryPoint)
.accessDeniedHandler(restAccessDeniedHandler);
return http.build(); return http.build();
} }

@ -1,17 +1,15 @@
package cn.teammodel.security.filter; package cn.teammodel.security.filter;
import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.security.utils.JwtTokenUtil; import cn.teammodel.security.utils.JwtTokenUtil;
import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import javax.annotation.Resource;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -23,31 +21,30 @@ import java.io.IOException;
* @author winter * @author winter
* @create 2023-11-09 10:43 * @create 2023-11-09 10:43
*/ */
@Component
@Slf4j @Slf4j
public class ApiAuthTokenFilter extends OncePerRequestFilter { public class ApiAuthTokenFilter extends OncePerRequestFilter {
@Resource JwtTokenUtil jwtTokenUtil = new JwtTokenUtil();
private JwtTokenUtil jwtTokenUtil;
// todo: 修改 context 的值 + 写一下多过滤器链的复盘
@Override @Override
protected void doFilterInternal(HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull FilterChain filterChain) throws ServletException, IOException { protected void doFilterInternal(HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull FilterChain filterChain) throws ServletException, IOException {
SecurityContext context = SecurityContextHolder.getContext(); SecurityContext context = SecurityContextHolder.getContext();
// 进入此过滤器说明 OIDC 认证成功,则验证 authToken
// 验证 authToken 合法 // 验证 authToken 合法
String token = request.getHeader("token"); String token = request.getHeader("token");
if (StringUtils.isBlank(token)) { if (StringUtils.isBlank(token)) {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
Claims claims = jwtTokenUtil.validAndGetClaims(token, "fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA=", 315360000); // Claims claims = jwtTokenUtil.validAndGetClaims(token, "fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA=", 315360000);
if (claims == null) { TmdUserDetail validUserDetail1 = jwtTokenUtil.getOutterTokenDetail(request);
if (validUserDetail1 == null) {
SecurityContextHolder.clearContext(); // 验证失败不应该在此处抛出异常,应该维护好 context 的值,以便整个过滤器链正常运行 SecurityContextHolder.clearContext(); // 验证失败不应该在此处抛出异常,应该维护好 context 的值,以便整个过滤器链正常运行
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
// 组装 authToken 的 jwt 进 authentication // 组装 authToken 的 jwt 进 authentication
UsernamePasswordAuthenticationToken finalAuthentication = new UsernamePasswordAuthenticationToken(claims, null, null); UsernamePasswordAuthenticationToken finalAuthentication = new UsernamePasswordAuthenticationToken(validUserDetail1, null, null);
context.setAuthentication(finalAuthentication); context.setAuthentication(finalAuthentication);
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} }

@ -3,13 +3,11 @@ package cn.teammodel.security.filter;
import cn.teammodel.model.entity.TmdUserDetail; import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.security.utils.JwtTokenUtil; import cn.teammodel.security.utils.JwtTokenUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
@ -24,12 +22,9 @@ import java.util.Collection;
* @author winter * @author winter
* @create 2023-11-09 10:43 * @create 2023-11-09 10:43
*/ */
@Component
@Slf4j @Slf4j
public class AuthInnerTokenFilter extends OncePerRequestFilter { public class AuthInnerTokenFilter extends OncePerRequestFilter {
JwtTokenUtil jwtTokenUtil = new JwtTokenUtil();
@Autowired
JwtTokenUtil jwtTokenUtil;
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {

@ -1,15 +1,14 @@
package cn.teammodel.security.utils; package cn.teammodel.security.utils;
import cn.teammodel.model.entity.User; import cn.hutool.extra.spring.SpringUtil;
import cn.teammodel.model.entity.TmdUserDetail; import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.model.entity.User;
import io.jsonwebtoken.Claims; import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts; import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm; import io.jsonwebtoken.SignatureAlgorithm;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.SecretKeySpec;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -22,13 +21,15 @@ import java.util.stream.Collectors;
* @date 20221124 10:50 * @date 20221124 10:50
* @description jwt * @description jwt
*/ */
@Component
@Slf4j @Slf4j
public class JwtTokenUtil { public class JwtTokenUtil {
private static final long NEVER_EXPIRE = 315360000; // 没有永不过期的api: 让时钟偏移十年 private static final long NEVER_EXPIRE = 315360000 ; // 没有永不过期的api: 让时钟偏移十年
@Value("${jwt.secret}") private final String secret;
private String secret;
{
secret = SpringUtil.getProperty("jwt.secret");
}
/** /**
* token * token
@ -159,6 +160,22 @@ public class JwtTokenUtil {
return tmdUserDetail; return tmdUserDetail;
} }
public TmdUserDetail getOutterTokenDetail(HttpServletRequest request) {
String token = request.getHeader("token");
if (StringUtils.isBlank(token)) {
return null;
}
Claims claims = getClaimsFromToken(token);
if (claims == null) {
return null;
}
// 组装 TmdUserDetail
TmdUserDetail tmdUserDetail = new TmdUserDetail();
tmdUserDetail.setClaims(claims);
return tmdUserDetail;
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private Set<String> convertToArray(Object o) { private Set<String> convertToArray(Object o) {
if (o == null) { if (o == null) {

@ -49,8 +49,7 @@ public class SecurityUtil
/** /**
* *
**/ **/
public static User getLoginUser() public static User getLoginUser() {
{
try try
{ {
return ((TmdUserDetail) getAuthentication().getPrincipal()).getUser(); return ((TmdUserDetail) getAuthentication().getPrincipal()).getUser();

@ -11,5 +11,5 @@ public interface ChatMessageService {
/** /**
* AI * AI
*/ */
SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto); SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId);
} }

@ -11,13 +11,13 @@ import java.util.List;
*/ */
public interface ChatSessionService { public interface ChatSessionService {
String createSession(); String createSession(String userId, String name);
List<ChatSession> listMySession(); List<ChatSession> listMySession(String userId);
ChatSession updateSession(UpdateSessionDto updateSessionDto); ChatSession updateSession(UpdateSessionDto updateSessionDto, String userId);
void deleteSession(String id); void deleteSession(String id, String userId);
List<ChatSession.Message> listMyHistory(String sessionId); List<ChatSession.Message> listHistory(String sessionId, String userId);
} }

@ -45,12 +45,12 @@ public class ChatMessageServiceImpl implements ChatMessageService {
@Override @Override
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) { public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) {
// 目前仅使用讯飞星火大模型 // 目前仅使用讯飞星火大模型
String appId = chatCompletionReqDto.getAppId(); String appId = chatCompletionReqDto.getAppId();
SseEmitter sseEmitter; SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) { if (StringUtils.isEmpty(appId)) {
sseEmitter = completionBySession(chatCompletionReqDto); sseEmitter = completionBySession(chatCompletionReqDto, userId);
} else { } else {
sseEmitter = completionByApp(chatCompletionReqDto, false); sseEmitter = completionByApp(chatCompletionReqDto, false);
} }
@ -114,10 +114,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
/** /**
* *
*/ */
private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto) { private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto, String userId) {
// User user = SecurityUtil.getLoginUser();
// String userId = user.getId();
String userId = "1595321354";
String userPrompt = chatCompletionReqDto.getText(); String userPrompt = chatCompletionReqDto.getText();
String sessionId = chatCompletionReqDto.getSessionId(); String sessionId = chatCompletionReqDto.getSessionId();

@ -33,12 +33,9 @@ public class ChatSessionServiceImpl implements ChatSessionService {
private ChatSessionRepository chatSessionRepository; private ChatSessionRepository chatSessionRepository;
@Override @Override
public String createSession() { public String createSession(String userId, String name) {
// todo User user = SecurityUtil.getLoginUser(); || 2. user.getName()
// String userId = user.getId();
String userId = "1595321354";
// 初始化欢迎语 // 初始化欢迎语
Message message = Message.of("", "你好" + "罗老师" + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!"); Message message = Message.of("", "你好" + name + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
List<Message> history = Collections.singletonList(message); List<Message> history = Collections.singletonList(message);
ChatSession chatSession = new ChatSession(); ChatSession chatSession = new ChatSession();
chatSession.setId(UUID.randomUUID().toString()); chatSession.setId(UUID.randomUUID().toString());
@ -52,9 +49,8 @@ public class ChatSessionServiceImpl implements ChatSessionService {
} }
@Override @Override
public List<ChatSession> listMySession() { public List<ChatSession> listMySession(String userId) {
// String userId = SecurityUtil.getUserId();
String userId = "1595321354";
List<ChatSession> sessions = chatSessionRepository.findByUserId(userId); List<ChatSession> sessions = chatSessionRepository.findByUserId(userId);
// 按更新时间排序 // 按更新时间排序
if (ObjectUtils.isNotEmpty(sessions)) { if (ObjectUtils.isNotEmpty(sessions)) {
@ -64,12 +60,9 @@ public class ChatSessionServiceImpl implements ChatSessionService {
} }
@Override @Override
public ChatSession updateSession(UpdateSessionDto updateSessionDto) { public ChatSession updateSession(UpdateSessionDto updateSessionDto, String userId) {
String id = updateSessionDto.getId(); String id = updateSessionDto.getId();
String title = updateSessionDto.getTitle(); String title = updateSessionDto.getTitle();
// User user = SecurityUtil.getLoginUser();
// String userId = user.getId();
String userId = "1595321354";
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), ""); ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "");
if (!session.getUserId().equals(userId)) { if (!session.getUserId().equals(userId)) {
@ -82,10 +75,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
} }
@Override @Override
public void deleteSession(String id) { public void deleteSession(String id, String userId) {
// User user = SecurityUtil.getLoginUser();
// String userId = user.getId();
String userId = "1595321354";
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "该会话不存在"); ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "该会话不存在");
// 鉴权 // 鉴权
if (!session.getUserId().equals(userId)) { if (!session.getUserId().equals(userId)) {
@ -95,9 +85,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
} }
@Override @Override
public List<Message> listMyHistory(String sessionId) { public List<Message> listHistory(String sessionId, String userId) {
// User user = SecurityUtil.getLoginUser();
String userId = "1595321354";
ChatSession session = chatSessionRepository.findChatSessionByIdAndCode(sessionId, PK.CHAT_SESSION); ChatSession session = chatSessionRepository.findChatSessionByIdAndCode(sessionId, PK.CHAT_SESSION);
if (!userId.equals(session.getUserId())) { if (!userId.equals(session.getUserId())) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR); throw new ServiceException(ErrorCode.NO_AUTH_ERROR);

Loading…
Cancel
Save