// The HTTP utility class used below (HttpUtil) is described at: https://blog.csdn.net/qq_37686944/article/details/136424921
package com.station.gpt;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import com.station.utils.HttpUtil;
import lombok.Data;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.Scanner;
/**
* Yi-34B-Chat调用demo
*
* 百度智能云千帆大模型平台:Yi-34B-Chat。目前在线服务免费使用,其它模型根据token按量计费
*
* 1、模型使用说明:
* https://console.bce.baidu.com/qianfan/ais/console/onlineService
*
* 2、AKSK,即生成签名token时的CLIENT_ID和CLIENT_SECRET:
* 创建应用后,API Key和Secret Key分别对应CLIENT_ID和CLIENT_SECRET
* https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
*
*/
@Slf4j
public class WenXinChat {
private static final ThreadLocal<ChatContext> CHAT_CONTEXT_THREAD_LOCAL = new ThreadLocal<>();
private static final String CLIENT_ID = "见上方文档说明";
private static final String CLIENT_SECRET = "见上方文档说明";
public String addChatMessage(String message) {
ChatContext chatContext = CHAT_CONTEXT_THREAD_LOCAL.get();
chatContext = Optional.ofNullable(chatContext).orElseGet(ChatContext::new);
String token = chatContext.getToken();
token = Optional.ofNullable(token).orElse(calToken());
List<ChatMessage> chatMessageList = chatContext.getChatMessageList();
chatMessageList = Optional.ofNullable(chatMessageList).orElseGet(Lists::newArrayList);
ChatMessage chatMessage = new ChatMessage().setRole(ChatRole.USER.getRoleName()).setContent(message);
chatMessageList.add(chatMessage);
chatContext.setToken(token);
chatContext.setChatMessageList(chatMessageList);
setContext(chatContext);
return getChatResponse(chatContext);
}
private void setContext(ChatContext chatContext) {
CHAT_CONTEXT_THREAD_LOCAL.set(chatContext);
}
private void removeContext() {
CHAT_CONTEXT_THREAD_LOCAL.remove();
}
private String calToken() {
String response = HttpUtil.postJson("https://aip.baidubce.com/oauth/2.0/token",
null,
new HashMap<String, Object>() {{
put("grant_type", "client_credentials");
put("client_id", CLIENT_ID);
put("client_secret", CLIENT_SECRET);
}},
null);
return Optional.ofNullable(response)
.map(JSON::parseObject)
.map(e -> e.getString("access_token"))
.orElseThrow(() -> new RuntimeException("获取token失败"));
}
private String getChatResponse(ChatContext chatContext) {
String response = HttpUtil.postJson("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat",
null,
new HashMap<String, Object>() {{
put("access_token", chatContext.getToken());
}},
new JSONObject().fluentPut("messages", JSON.toJSON(chatContext.getChatMessageList())));
return Optional.ofNullable(response)
.map(JSON::parseObject)
.map(e -> e.getString("result"))
.map(chatResponse -> {
CHAT_CONTEXT_THREAD_LOCAL
.get()
.getChatMessageList()
.add(new ChatMessage().setRole(ChatRole.ASSISTANT.getRoleName()).setContent(chatResponse));
return chatResponse;
})
.orElseThrow(() -> new RuntimeException("获取chat结果失败"));
}
public static void main(String[] args) {
WenXinChat wenXinChat = new WenXinChat();
Scanner scanner = new Scanner(System.in);
while (scanner.hasNextLine()) {
String sendMessage = scanner.nextLine();
String receiveMessage = null;
try {
receiveMessage = wenXinChat.addChatMessage(sendMessage);
log.info("助手回答:{}", receiveMessage);
} catch (Exception e) {
log.error("对话异常:{}" + e.getMessage());
wenXinChat.removeContext();
}
}
wenXinChat.removeContext();
}
@Data
static class ChatContext {
private String token;
private List<ChatMessage> chatMessageList;
}
@Data
@Accessors(chain = true)
static class ChatMessage {
private String role;
private String content;
}
enum ChatRole {
USER("user"),
ASSISTANT("assistant");
private String roleName;
ChatRole(String roleName) {
this.roleName = roleName;
}
public String getRoleName() {
return roleName;
}
}
}