fix: 修复 api

11111
winter 9 months ago
parent acf448e980
commit a53efd1f7a

@ -3,13 +3,16 @@ package cn.teammodel.controller.frontend;
import cn.teammodel.common.IdRequest;
import cn.teammodel.common.R;
import cn.teammodel.model.dto.ai.*;
import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.model.entity.ai.ChatApp;
import cn.teammodel.model.entity.ai.ChatSession;
import cn.teammodel.security.utils.SecurityUtil;
import cn.teammodel.service.ChatAppService;
import cn.teammodel.service.ChatMessageService;
import cn.teammodel.service.ChatSessionService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@ -33,7 +36,8 @@ public class AiController {
@PostMapping("chat/completion")
@ApiOperation("与 spark 的流式对话")
public SseEmitter chatCompletion(@RequestBody @Valid ChatCompletionReqDto chatCompletionReqDto) {
return chatMessageService.chatCompletion(chatCompletionReqDto);
String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
return chatMessageService.chatCompletion(chatCompletionReqDto, userId);
}
// @PostMapping("chat/test/completion")
@ -73,33 +77,40 @@ public class AiController {
@GetMapping("session/my")
@ApiOperation("查询我的聊天会话")
public R<List<ChatSession>> listMySession() {
List<ChatSession> sessions = chatSessionService.listMySession();
String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
List<ChatSession> sessions = chatSessionService.listMySession(userId);
return R.success(sessions);
}
@GetMapping("chat/history/{sessionId}")
@ApiOperation("查询我的聊天记录")
public R<List<ChatSession.Message>> getHistory(@PathVariable String sessionId) {
List<ChatSession.Message> history = chatSessionService.listMyHistory(sessionId);
String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
List<ChatSession.Message> history = chatSessionService.listHistory(sessionId, userId);
return R.success(history);
}
@PostMapping("session/create")
@ApiOperation("创建聊天会话")
public R<String> createSession() {
String sessionId = chatSessionService.createSession();
String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
String name = (String) ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().get("name");
name = StringUtils.isBlank(name) ? "老师" : name;
String sessionId = chatSessionService.createSession(userId, name);
return R.success(sessionId);
}
@PostMapping("session/remove")
@ApiOperation("删除聊天会话")
public R<String> removeSession(@RequestBody @Valid IdRequest idRequest) {
chatSessionService.deleteSession(idRequest.getId());
String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
chatSessionService.deleteSession(idRequest.getId(), userId);
return R.success("删除会话成功");
}
@PostMapping("session/update")
@ApiOperation("更新聊天会话")
public R<ChatSession> updateSession(@RequestBody @Valid UpdateSessionDto updateSessionDto) {
ChatSession session = chatSessionService.updateSession(updateSessionDto);
String userId = ((TmdUserDetail) SecurityUtil.getAuthentication().getPrincipal()).getClaims().getSubject();
ChatSession session = chatSessionService.updateSession(updateSessionDto, userId);
return R.success(session);
}

@ -1,5 +1,6 @@
package cn.teammodel.model.entity;
import io.jsonwebtoken.Claims;
import lombok.Data;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
@ -14,6 +15,8 @@ import java.util.Collection;
@Data
public class TmdUserDetail implements UserDetails {
private User user;
private Claims claims;
@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
return null;

@ -31,10 +31,6 @@ public class SecurityConfiguration {
private RestAccessDeniedHandler restAccessDeniedHandler;
@Resource
private RestAuthenticationEntryPoint restAuthenticationEntryPoint;
@Resource
private AuthInnerTokenFilter authInnerTokenFilter;
@Resource
private ApiAuthTokenFilter apiAuthTokenFilter;
@Bean
@Order(2)
@ -65,7 +61,7 @@ public class SecurityConfiguration {
.anyRequest().authenticated()
)
.oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) // 启用 OIDC jwt filter
.addFilterAfter(authInnerTokenFilter, BearerTokenAuthenticationFilter.class) // 添加 x-auth-authToken filter
.addFilterAfter(new AuthInnerTokenFilter(), BearerTokenAuthenticationFilter.class) // 添加 x-auth-authToken filter
.exceptionHandling()
.authenticationEntryPoint(restAuthenticationEntryPoint)
.accessDeniedHandler(restAccessDeniedHandler);
@ -74,13 +70,25 @@ public class SecurityConfiguration {
@Bean
@Order(1)
public SecurityFilterChain outterApiFilterChain(HttpSecurity http) throws Exception {
http.
antMatcher("/ai/api/**")
http
// CSRF禁用因为不使用session
.csrf().disable()
.cors().configurationSource(corsConfigurationSource())
.and()
// 禁用HTTP响应标头
.headers().cacheControl().disable()
.and()
.sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS)
.and()
.antMatcher("/ai/api/**")
.authorizeRequests(authorizeRequests ->
authorizeRequests
.anyRequest().authenticated()
)
.addFilterAfter(apiAuthTokenFilter, BearerTokenAuthenticationFilter.class);
.addFilterAfter(new ApiAuthTokenFilter(), BearerTokenAuthenticationFilter.class)
.exceptionHandling()
.authenticationEntryPoint(restAuthenticationEntryPoint)
.accessDeniedHandler(restAccessDeniedHandler);
return http.build();
}

@ -1,17 +1,15 @@
package cn.teammodel.security.filter;
import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.security.utils.JwtTokenUtil;
import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.annotation.Resource;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
@ -23,31 +21,30 @@ import java.io.IOException;
* @author winter
* @create 2023-11-09 10:43
*/
@Component
@Slf4j
public class ApiAuthTokenFilter extends OncePerRequestFilter {
@Resource
private JwtTokenUtil jwtTokenUtil;
JwtTokenUtil jwtTokenUtil = new JwtTokenUtil();
// todo: 修改 context 的值 + 写一下多过滤器链的复盘
@Override
protected void doFilterInternal(HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull FilterChain filterChain) throws ServletException, IOException {
SecurityContext context = SecurityContextHolder.getContext();
// 进入此过滤器说明 OIDC 认证成功,则验证 authToken
// 验证 authToken 合法
String token = request.getHeader("token");
if (StringUtils.isBlank(token)) {
filterChain.doFilter(request, response);
return;
}
Claims claims = jwtTokenUtil.validAndGetClaims(token, "fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA=", 315360000);
if (claims == null) {
// Claims claims = jwtTokenUtil.validAndGetClaims(token, "fXO6ko/qyXeYrkecPeKdgXnuLXf9vMEtnBC9OB3s+aA=", 315360000);
TmdUserDetail validUserDetail1 = jwtTokenUtil.getOutterTokenDetail(request);
if (validUserDetail1 == null) {
SecurityContextHolder.clearContext(); // 验证失败不应该在此处抛出异常,应该维护好 context 的值,以便整个过滤器链正常运行
filterChain.doFilter(request, response);
return;
}
// 组装 authToken 的 jwt 进 authentication
UsernamePasswordAuthenticationToken finalAuthentication = new UsernamePasswordAuthenticationToken(claims, null, null);
UsernamePasswordAuthenticationToken finalAuthentication = new UsernamePasswordAuthenticationToken(validUserDetail1, null, null);
context.setAuthentication(finalAuthentication);
filterChain.doFilter(request, response);
}

@ -3,13 +3,11 @@ package cn.teammodel.security.filter;
import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.security.utils.JwtTokenUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
@ -24,12 +22,9 @@ import java.util.Collection;
* @author winter
* @create 2023-11-09 10:43
*/
@Component
@Slf4j
public class AuthInnerTokenFilter extends OncePerRequestFilter {
@Autowired
JwtTokenUtil jwtTokenUtil;
JwtTokenUtil jwtTokenUtil = new JwtTokenUtil();
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {

@ -1,15 +1,14 @@
package cn.teammodel.security.utils;
import cn.teammodel.model.entity.User;
import cn.hutool.extra.spring.SpringUtil;
import cn.teammodel.model.entity.TmdUserDetail;
import cn.teammodel.model.entity.User;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import javax.crypto.spec.SecretKeySpec;
import javax.servlet.http.HttpServletRequest;
@ -22,13 +21,15 @@ import java.util.stream.Collectors;
* @date 20221124 10:50
* @description jwt
*/
@Component
@Slf4j
public class JwtTokenUtil {
private static final long NEVER_EXPIRE = 315360000 ; // 没有永不过期的api: 让时钟偏移十年
@Value("${jwt.secret}")
private String secret;
private final String secret;
{
secret = SpringUtil.getProperty("jwt.secret");
}
/**
* token
@ -159,6 +160,22 @@ public class JwtTokenUtil {
return tmdUserDetail;
}
public TmdUserDetail getOutterTokenDetail(HttpServletRequest request) {
String token = request.getHeader("token");
if (StringUtils.isBlank(token)) {
return null;
}
Claims claims = getClaimsFromToken(token);
if (claims == null) {
return null;
}
// 组装 TmdUserDetail
TmdUserDetail tmdUserDetail = new TmdUserDetail();
tmdUserDetail.setClaims(claims);
return tmdUserDetail;
}
@SuppressWarnings("unchecked")
private Set<String> convertToArray(Object o) {
if (o == null) {

@ -49,8 +49,7 @@ public class SecurityUtil
/**
*
**/
public static User getLoginUser()
{
public static User getLoginUser() {
try
{
return ((TmdUserDetail) getAuthentication().getPrincipal()).getUser();

@ -11,5 +11,5 @@ public interface ChatMessageService {
/**
* AI
*/
SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto);
SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId);
}

@ -11,13 +11,13 @@ import java.util.List;
*/
public interface ChatSessionService {
String createSession();
String createSession(String userId, String name);
List<ChatSession> listMySession();
List<ChatSession> listMySession(String userId);
ChatSession updateSession(UpdateSessionDto updateSessionDto);
ChatSession updateSession(UpdateSessionDto updateSessionDto, String userId);
void deleteSession(String id);
void deleteSession(String id, String userId);
List<ChatSession.Message> listMyHistory(String sessionId);
List<ChatSession.Message> listHistory(String sessionId, String userId);
}

@ -45,12 +45,12 @@ public class ChatMessageServiceImpl implements ChatMessageService {
@Override
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto) {
public SseEmitter chatCompletion(ChatCompletionReqDto chatCompletionReqDto, String userId) {
// 目前仅使用讯飞星火大模型
String appId = chatCompletionReqDto.getAppId();
SseEmitter sseEmitter;
if (StringUtils.isEmpty(appId)) {
sseEmitter = completionBySession(chatCompletionReqDto);
sseEmitter = completionBySession(chatCompletionReqDto, userId);
} else {
sseEmitter = completionByApp(chatCompletionReqDto, false);
}
@ -114,10 +114,7 @@ public class ChatMessageServiceImpl implements ChatMessageService {
/**
*
*/
private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto) {
// User user = SecurityUtil.getLoginUser();
// String userId = user.getId();
String userId = "1595321354";
private SseEmitter completionBySession(ChatCompletionReqDto chatCompletionReqDto, String userId) {
String userPrompt = chatCompletionReqDto.getText();
String sessionId = chatCompletionReqDto.getSessionId();

@ -33,12 +33,9 @@ public class ChatSessionServiceImpl implements ChatSessionService {
private ChatSessionRepository chatSessionRepository;
@Override
public String createSession() {
// todo User user = SecurityUtil.getLoginUser(); || 2. user.getName()
// String userId = user.getId();
String userId = "1595321354";
public String createSession(String userId, String name) {
// 初始化欢迎语
Message message = Message.of("", "你好" + "罗老师" + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
Message message = Message.of("", "你好" + name + " ,我是你的私人 AI 助手小豆,你可以问我任何包括但不仅限于教育的问题,我会尽力为您解答!");
List<Message> history = Collections.singletonList(message);
ChatSession chatSession = new ChatSession();
chatSession.setId(UUID.randomUUID().toString());
@ -52,9 +49,8 @@ public class ChatSessionServiceImpl implements ChatSessionService {
}
@Override
public List<ChatSession> listMySession() {
// String userId = SecurityUtil.getUserId();
String userId = "1595321354";
public List<ChatSession> listMySession(String userId) {
List<ChatSession> sessions = chatSessionRepository.findByUserId(userId);
// 按更新时间排序
if (ObjectUtils.isNotEmpty(sessions)) {
@ -64,12 +60,9 @@ public class ChatSessionServiceImpl implements ChatSessionService {
}
@Override
public ChatSession updateSession(UpdateSessionDto updateSessionDto) {
public ChatSession updateSession(UpdateSessionDto updateSessionDto, String userId) {
String id = updateSessionDto.getId();
String title = updateSessionDto.getTitle();
// User user = SecurityUtil.getLoginUser();
// String userId = user.getId();
String userId = "1595321354";
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "");
if (!session.getUserId().equals(userId)) {
@ -82,10 +75,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
}
@Override
public void deleteSession(String id) {
// User user = SecurityUtil.getLoginUser();
// String userId = user.getId();
String userId = "1595321354";
public void deleteSession(String id, String userId) {
ChatSession session = RepositoryUtil.findOne(chatSessionRepository.findBySessionId(id), "该会话不存在");
// 鉴权
if (!session.getUserId().equals(userId)) {
@ -95,9 +85,7 @@ public class ChatSessionServiceImpl implements ChatSessionService {
}
@Override
public List<Message> listMyHistory(String sessionId) {
// User user = SecurityUtil.getLoginUser();
String userId = "1595321354";
public List<Message> listHistory(String sessionId, String userId) {
ChatSession session = chatSessionRepository.findChatSessionByIdAndCode(sessionId, PK.CHAT_SESSION);
if (!userId.equals(session.getUserId())) {
throw new ServiceException(ErrorCode.NO_AUTH_ERROR);

Loading…
Cancel
Save