package cn.teammodel.ai;

import cn.hutool.json.JSONUtil;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.TimeUnit;

/**
 * Client for the Spark GPT WebSocket chat API.
 *
 * <p>Builds a signed (HMAC-SHA256) connection URL from the configured endpoint, API key and
 * API secret, and opens an OkHttp WebSocket per chat request.
 *
 * @author winter
 * @create 2023-12-15 14:29
 */
@Component
@Data
@Slf4j
public class SparkGptClient implements InitializingBean {

    /** JCA algorithm name used both for the {@link Mac} and for its key spec. */
    private static final String HMAC_ALGORITHM = "hmacsha256";

    /** Connect/read/write timeout applied to the shared HTTP client, in seconds. */
    private static final long TIMEOUT_SECONDS = 90;

    @Resource
    private SparkGptProperties sparkGptProperties;

    private OkHttpClient okHttpClient;

    /**
     * Signed WebSocket URL generated once by {@link #init()}.
     *
     * <p>Kept for getter/setter compatibility only; requests sign a fresh URL instead, because
     * the signature covers the generation timestamp.
     */
    private String authUrl;

    /**
     * Initializes the client: signs an initial URL (so that a broken configuration fails at
     * startup) and builds the shared {@link OkHttpClient}.
     */
    public void init() {
        this.authUrl = buildWebSocketAuthUrl();
        // The signed URL embeds the API key and a valid signature, so keep it out of INFO logs.
        log.debug("[SPARK CHAT] 鉴权 url: {}", this.authUrl);
        this.okHttpClient = new OkHttpClient.Builder()
                .connectTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .readTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .writeTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .build();
    }

    /**
     * Opens a WebSocket to the Spark endpoint and hands the request JSON to the listener, which
     * receives the streamed result.
     *
     * <p>Failures while preparing or opening the connection are logged and not rethrown.
     *
     * @param param    chat request; its app id is overwritten with the configured one
     * @param listener WebSocket listener that receives the request JSON and the stream events
     */
    public void streamChatCompletion(SparkChatRequestParam param, SparkGptStreamListener listener) {
        try {
            param.setAppId(sparkGptProperties.getAppId());
            // Sign a new URL for every connection: the signature includes the current date, so
            // a URL cached at startup goes stale.
            // NOTE(review): the vendor documentation describes a server-side clock-skew check on
            // this date (about 300 seconds) - confirm against the current API docs.
            Request request = new Request.Builder().url(buildWebSocketAuthUrl()).build();
            // Hand the payload to the listener before the socket is opened.
            listener.setRequestJson(param.toJsonParams());
            log.info("请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty());
            okHttpClient.newWebSocket(request, listener);
        } catch (Exception e) {
            // Trailing throwable makes SLF4J log the stack trace (replaces printStackTrace()).
            log.error("Spark AI 请求异常: {}", e.getMessage(), e);
        }
    }

    /**
     * Generates the signed authentication URL for the given endpoint.
     *
     * <p>The signature is the Base64-encoded HMAC-SHA256 (keyed with {@code apiSecret}) of the
     * {@code host}, {@code date} and request-line of the endpoint. The returned URL always uses
     * the {@code https} scheme with the endpoint's host and path, and carries the
     * {@code authorization}, {@code date} and {@code host} query parameters.
     *
     * @param endpoint  endpoint URL; must be parseable by {@link URL}
     * @param apiKey    API key placed in the authorization parameter
     * @param apiSecret secret used as the HMAC key
     * @return the signed {@code https} URL
     * @throws IllegalStateException if the endpoint cannot be parsed or the HMAC cannot be
     *                               initialized with the given secret
     */
    public static String genAuthUrl(String endpoint, String apiKey, String apiSecret) {
        final URL url;
        final Mac mac;
        try {
            url = new URL(endpoint);
            mac = Mac.getInstance(HMAC_ALGORITHM);
            mac.init(new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), HMAC_ALGORITHM));
        } catch (Exception e) {
            // Never log the API key or secret; the endpoint is enough to locate the bad config.
            log.error("生成鉴权URL失败, endpoint: {}", endpoint, e);
            throw new IllegalStateException("Failed to generate Spark auth URL for endpoint: " + endpoint, e);
        }

        // Timestamp formatted in GMT with English day/month names.
        SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
        format.setTimeZone(TimeZone.getTimeZone("GMT"));
        String date = format.format(new Date());

        // String to sign: host, date and request-line, newline separated.
        String preStr = "host: " + url.getHost() + "\n"
                + "date: " + date + "\n"
                + "GET " + url.getPath() + " HTTP/1.1";

        // HMAC-SHA256 signature, Base64-encoded.
        byte[] signatureBytes = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
        String signature = Base64.getEncoder().encodeToString(signatureBytes);

        String authorization = String.format(
                "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"",
                apiKey, "hmac-sha256", "host date request-line", signature);

        // Assemble the final URL; the authorization value is Base64-encoded once more.
        HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath()))
                .newBuilder()
                .addQueryParameter("authorization",
                        Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8)))
                .addQueryParameter("date", date)
                .addQueryParameter("host", url.getHost())
                .build();
        return httpUrl.toString();
    }

    /** Runs {@link #init()} once Spring has injected the properties. */
    @Override
    public void afterPropertiesSet() throws Exception {
        init();
    }

    /** Signs a URL with the configured credentials and switches its scheme to ws/wss. */
    private String buildWebSocketAuthUrl() {
        String signedUrl = genAuthUrl(
                sparkGptProperties.getEndpoint(),
                sparkGptProperties.getApiKey(),
                sparkGptProperties.getApiSecret());
        return signedUrl.replace("http://", "ws://").replace("https://", "wss://");
    }
}