You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

120 lines
4.5 KiB

package cn.teammodel.ai;
import cn.hutool.json.JSONUtil;
import cn.teammodel.ai.cache.HistoryCache;
import cn.teammodel.ai.domain.SparkChatRequestParam;
import cn.teammodel.ai.listener.SparkGptStreamListener;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
* spark Gpt client
* @author winter
* @create 2023-12-15 14:29
*/
@Component
@Data
@Slf4j
public class SparkGptClient implements InitializingBean {
@Resource
private SparkGptProperties sparkGptProperties;
private OkHttpClient okHttpClient;
private String authUrl;
/**
* 静态构造对象方法
*/
public void init() {
// 初始化缓存
HistoryCache.init(sparkGptProperties.getCache_timeout(), sparkGptProperties.getCache_context());
// 初始化 authUrl
authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret());
this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://");
log.info("[SPARK CHAT] 鉴权 endpoint : {}", this.authUrl);
// 初始化 okHttpClient
this.okHttpClient = new OkHttpClient()
.newBuilder()
.connectTimeout(90, TimeUnit.SECONDS)
.readTimeout(90, TimeUnit.SECONDS)
.writeTimeout(90, TimeUnit.SECONDS)
.build();
}
/**
* 流式的生成结果,以 sse 的方式向客户端推送
*/
public void streamChatCompletion(SparkChatRequestParam param, SparkGptStreamListener listener) {
try {
param.setAppId(sparkGptProperties.getAppId());
Request request = new Request.Builder().url(authUrl).build();
// 设置请求参数
listener.setRequestJson(param.toJsonParams());
log.info("[SPARK CHAT] 请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty());
okHttpClient.newWebSocket(request, listener);
} catch (Exception e) {
log.error("[SPARK CHAT] Spark AI 请求异常: {}", e.getMessage());
e.printStackTrace();
}
}
/**
* 生成鉴权URL
*/
public static String genAuthUrl(String endpoint, String apiKey, String apiSecret) {
URL url = null;
String date = null;
String preStr = null;
Mac mac = null;
try {
url = new URL(endpoint);
// 时间
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
format.setTimeZone(TimeZone.getTimeZone("GMT"));
date = format.format(new Date());
// 拼接
preStr = "host: " + url.getHost() + "\n" +
"date: " + date + "\n" +
"GET " + url.getPath() + " HTTP/1.1";
// SHA256加密
mac = Mac.getInstance("hmacsha256");
SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
mac.init(spec);
} catch (Exception e) {
log.error("[SPARK CHAT] 生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret);
throw new RuntimeException(e);
}
byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
// Base64加密
String sha = Base64.getEncoder().encodeToString(hexDigits);
// 拼接
String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
// 拼接地址
HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().
addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).
addQueryParameter("date", date).//
addQueryParameter("host", url.getHost()).//
build();
return httpUrl.toString();
}
@Override
public void afterPropertiesSet() throws Exception {
init();
}
}