You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
4.5 KiB
120 lines
4.5 KiB
package cn.teammodel.ai;
|
|
|
|
import cn.hutool.json.JSONUtil;
|
|
import cn.teammodel.ai.cache.HistoryCache;
|
|
import cn.teammodel.ai.domain.SparkChatRequestParam;
|
|
import cn.teammodel.ai.listener.SparkGptStreamListener;
|
|
import lombok.Data;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import okhttp3.HttpUrl;
|
|
import okhttp3.OkHttpClient;
|
|
import okhttp3.Request;
|
|
import org.springframework.beans.factory.InitializingBean;
|
|
import org.springframework.stereotype.Component;
|
|
|
|
import javax.annotation.Resource;
|
|
import javax.crypto.Mac;
|
|
import javax.crypto.spec.SecretKeySpec;
|
|
import java.net.URL;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.text.SimpleDateFormat;
|
|
import java.util.*;
|
|
import java.util.concurrent.TimeUnit;
|
|
|
|
/**
|
|
* spark Gpt client
|
|
* @author winter
|
|
* @create 2023-12-15 14:29
|
|
*/
|
|
@Component
|
|
@Data
|
|
@Slf4j
|
|
public class SparkGptClient implements InitializingBean {
|
|
@Resource
|
|
private SparkGptProperties sparkGptProperties;
|
|
private OkHttpClient okHttpClient;
|
|
private String authUrl;
|
|
|
|
/**
|
|
* 静态构造对象方法
|
|
*/
|
|
public void init() {
|
|
// 初始化缓存
|
|
HistoryCache.init(sparkGptProperties.getCache_timeout(), sparkGptProperties.getCache_context());
|
|
// 初始化 authUrl
|
|
authUrl = genAuthUrl(sparkGptProperties.getEndpoint(), sparkGptProperties.getApiKey(), sparkGptProperties.getApiSecret());
|
|
this.authUrl = authUrl.replace("http://", "ws://").replace("https://", "wss://");
|
|
log.info("[SPARK CHAT] 鉴权 endpoint : {}", this.authUrl);
|
|
// 初始化 okHttpClient
|
|
this.okHttpClient = new OkHttpClient()
|
|
.newBuilder()
|
|
.connectTimeout(90, TimeUnit.SECONDS)
|
|
.readTimeout(90, TimeUnit.SECONDS)
|
|
.writeTimeout(90, TimeUnit.SECONDS)
|
|
.build();
|
|
}
|
|
|
|
/**
|
|
* 流式的生成结果,以 sse 的方式向客户端推送
|
|
*/
|
|
public void streamChatCompletion(SparkChatRequestParam param, SparkGptStreamListener listener) {
|
|
try {
|
|
param.setAppId(sparkGptProperties.getAppId());
|
|
Request request = new Request.Builder().url(authUrl).build();
|
|
// 设置请求参数
|
|
listener.setRequestJson(param.toJsonParams());
|
|
log.info("[SPARK CHAT] 请求参数 {}", JSONUtil.parseObj(param.toJsonParams()).toStringPretty());
|
|
okHttpClient.newWebSocket(request, listener);
|
|
} catch (Exception e) {
|
|
log.error("[SPARK CHAT] Spark AI 请求异常: {}", e.getMessage());
|
|
e.printStackTrace();
|
|
}
|
|
}
|
|
|
|
/**
|
|
* 生成鉴权URL
|
|
*/
|
|
public static String genAuthUrl(String endpoint, String apiKey, String apiSecret) {
|
|
URL url = null;
|
|
String date = null;
|
|
String preStr = null;
|
|
Mac mac = null;
|
|
try {
|
|
url = new URL(endpoint);
|
|
// 时间
|
|
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
|
|
format.setTimeZone(TimeZone.getTimeZone("GMT"));
|
|
date = format.format(new Date());
|
|
// 拼接
|
|
preStr = "host: " + url.getHost() + "\n" +
|
|
"date: " + date + "\n" +
|
|
"GET " + url.getPath() + " HTTP/1.1";
|
|
// SHA256加密
|
|
mac = Mac.getInstance("hmacsha256");
|
|
SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
|
|
mac.init(spec);
|
|
} catch (Exception e) {
|
|
log.error("[SPARK CHAT] 生成鉴权URL失败, endpoint: {}, apiKey: {}, apiSecret: {}", endpoint, apiKey, apiSecret);
|
|
throw new RuntimeException(e);
|
|
}
|
|
|
|
byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
|
|
// Base64加密
|
|
String sha = Base64.getEncoder().encodeToString(hexDigits);
|
|
// 拼接
|
|
String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
|
|
// 拼接地址
|
|
HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().
|
|
addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).
|
|
addQueryParameter("date", date).//
|
|
addQueryParameter("host", url.getHost()).//
|
|
build();
|
|
return httpUrl.toString();
|
|
}
|
|
|
|
@Override
|
|
public void afterPropertiesSet() throws Exception {
|
|
init();
|
|
}
|
|
}
|