1. 背景
项目需要对接某个大模型的 WebSocket 对话接口,同时还要对接另一个通过 HTTP 请求的大模型服务。为了支持模型服务切换,并向客户端提供统一的流式响应,需要对大模型服务的调用做统一的流式响应封装,响应数据统一通过 callback 接收。
2. 代码
2.1 pom
<dependency>
<groupId>org.java-websocket</groupId>
<artifactId>Java-WebSocket</artifactId>
<version>1.5.3</version>
</dependency>
2.2 回调函数
/**
 * Receives the messages produced while talking to a large language model.
 */
@FunctionalInterface
public interface LLMChatCallback {

    /**
     * Handles a single message.
     *
     * @param message the message text handed to the callback
     */
    void callback(String message);
}
2.3 统一对话接口
/**
 * Common entry point for a model conversation.
 */
public interface LLMService {

    /**
     * Runs one model conversation.
     *
     * @param payload  the parameters used to request the large language model
     * @param callback response callback; the caller implements it to receive messages
     */
    void dialogue(String payload, LLMChatCallback callback);
}
2.4 自定义WebSocketClient
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.enums.ReadyState;
import org.java_websocket.handshake.ServerHandshake;
import java.net.URI;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * WebSocket client for a large-language-model chat endpoint.
 *
 * <p>{@link #send(String, LLMChatCallback)} turns the asynchronous WebSocket exchange into a
 * blocking call: it waits for the connection to open, sends the payload, waits for the first
 * message and finally waits for the connection to be closed. Every text message received is
 * forwarded to the supplied callback.
 *
 * <p>Waiting threads block on an internal monitor that the lifecycle callbacks
 * ({@link #onOpen}, {@link #onMessage}, {@link #onClose}) notify, so no CPU is burned while
 * waiting.
 **/
@Slf4j
public class LLMWebSocketClient extends WebSocketClient {

    /**
     * Message callback. Written by the thread calling {@code send} and read by the thread that
     * delivers {@code onMessage}, hence volatile.
     */
    private volatile LLMChatCallback callback = null;

    /**
     * Whether a message has been received: cleared when a send starts and when the connection
     * closes, set when a message arrives. Kept public for existing readers; the internal waits
     * rely on {@code responseReceived} instead, which is not cleared by a close.
     */
    public AtomicBoolean hasMessage = new AtomicBoolean(false);

    /** Maximum time to wait for the connection to open, in milliseconds. */
    private final long connectionOpenTimeout;

    /** Maximum time to wait for the first server response, in milliseconds. */
    private final long connectionTimeout;

    /** Maximum time to wait for the whole response (until the connection closes), in milliseconds. */
    private final long readTimeout;

    /** Monitor guarding the lifecycle flags below; callbacks notify it on every state change. */
    private final Object stateLock = new Object();

    /** True between onOpen and onClose. Guarded by stateLock. */
    private boolean connectionOpen;

    /** True once onClose has been seen since the last onOpen. Guarded by stateLock. */
    private boolean connectionClosed;

    /** True once a message has arrived for the current send. Guarded by stateLock. */
    private boolean responseReceived;

    /** Most recent exception reported through onError, used as a failure cause. Guarded by stateLock. */
    private Exception lastError;

    /**
     * Creates a client for the given endpoint. The connection is not opened here.
     *
     * @param serverUri             endpoint to connect to
     * @param connectionOpenTimeout maximum time to wait for the connection to open, in ms
     * @param connectionTimeout     maximum time to wait for the first response, in ms
     * @param readTimeout           maximum time to wait for the connection to close, in ms
     * @throws IllegalArgumentException if any timeout is {@code null}
     */
    public LLMWebSocketClient(URI serverUri, Long connectionOpenTimeout, Long connectionTimeout, Long readTimeout) {
        super(serverUri);
        this.connectionOpenTimeout = requireTimeout(connectionOpenTimeout, "connectionOpenTimeout");
        this.connectionTimeout = requireTimeout(connectionTimeout, "connectionTimeout");
        this.readTimeout = requireTimeout(readTimeout, "readTimeout");
        log.info("LLMWebSocketClient init:{}, connectionTimeout {}ms, readTimeout {}ms",
                serverUri, connectionTimeout, readTimeout);
    }

    /** Rejects a missing timeout up front instead of failing later while waiting. */
    private static long requireTimeout(Long value, String name) {
        if (value == null) {
            throw new IllegalArgumentException(name + " must not be null");
        }
        return value;
    }

    /**
     * Sends a message and blocks until the response has been fully delivered.
     *
     * @param payload  the data to send
     * @param callback response callback implemented by the caller to receive the response data
     * @throws IllegalStateException if the connection closes before any response arrives
     * @throws RuntimeException      if a wait times out or the calling thread is interrupted
     **/
    public void send(String payload, LLMChatCallback callback) {
        hasMessage.set(false);
        synchronized (stateLock) {
            responseReceived = false;
        }
        // The connection is established asynchronously; send only works once it is open.
        waitConnect();
        this.callback = callback;
        super.send(payload);
        // Wait for the first message.
        waitResponse();
        // Wait until the connection is closed, i.e. the response is complete.
        waitFinished();
    }

    /**
     * Blocks until the connection has been opened or {@code connectionOpenTimeout} elapses.
     **/
    private void waitConnect() {
        long startNanos = System.nanoTime();
        synchronized (stateLock) {
            while (!connectionOpen) {
                if (!awaitSignal(startNanos, connectionOpenTimeout)) {
                    throw new RuntimeException(String.format("wait connect timeout %sms", connectionOpenTimeout));
                }
            }
        }
        log.debug("connect open spend time: {}ms", elapsedMillis(startNanos));
    }

    /**
     * Blocks until the first message arrives or {@code connectionTimeout} elapses. Fails
     * immediately if the connection closes without having delivered a message.
     **/
    private void waitResponse() {
        long startNanos = System.nanoTime();
        synchronized (stateLock) {
            while (!responseReceived) {
                if (connectionClosed) {
                    throw new IllegalStateException(
                            "connection closed before any response was received", lastError);
                }
                if (!awaitSignal(startNanos, connectionTimeout)) {
                    throw new RuntimeException(String.format("send timeout %sms", connectionTimeout));
                }
            }
        }
    }

    /**
     * Blocks until the connection has been closed or {@code readTimeout} elapses.
     **/
    private void waitFinished() {
        long startNanos = System.nanoTime();
        synchronized (stateLock) {
            while (!connectionClosed) {
                if (!awaitSignal(startNanos, readTimeout)) {
                    throw new RuntimeException(String.format("read timeout %sms", readTimeout));
                }
            }
        }
        log.debug("read spend time: {}ms", elapsedMillis(startNanos));
    }

    /**
     * Waits on {@code stateLock} for the next state change. Must be called while holding
     * {@code stateLock}.
     *
     * @return {@code false} if the timeout had already elapsed, {@code true} after waiting
     */
    private boolean awaitSignal(long startNanos, long timeoutMillis) {
        long remainingMillis = timeoutMillis - elapsedMillis(startNanos);
        if (remainingMillis <= 0) {
            // Never call wait(0): that would block indefinitely.
            return false;
        }
        try {
            stateLock.wait(remainingMillis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("interrupted while waiting for websocket state change", e);
        }
        return true;
    }

    /** Milliseconds elapsed since the given {@link System#nanoTime()} reading. */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }

    @Override
    public void onOpen(ServerHandshake handshake) {
        log.info("WebSocketClient onOpen: {}", handshake.getHttpStatusMessage());
        synchronized (stateLock) {
            connectionOpen = true;
            connectionClosed = false;
            lastError = null;
            stateLock.notifyAll();
        }
    }

    @Override
    public void onMessage(String message) {
        log.info("WebSocketClient onMessage: {}", message);
        hasMessage.set(true);
        synchronized (stateLock) {
            responseReceived = true;
            stateLock.notifyAll();
        }
        // Read the field once so the null check and the call see the same callback.
        LLMChatCallback current = callback;
        if (current != null) {
            // Business-specific processing could be added here.
            current.callback(message);
        }
    }

    @Override
    public void onClose(int i, String s, boolean b) {
        this.hasMessage.set(false);
        log.info("WebSocketClient onClose:{}", s);
        synchronized (stateLock) {
            connectionOpen = false;
            connectionClosed = true;
            stateLock.notifyAll();
        }
    }

    @Override
    public void onError(Exception e) {
        log.error("WebSocketClient onError", e);
        synchronized (stateLock) {
            // Only recorded; an error does not necessarily end the connection.
            lastError = e;
        }
    }
}
2.5 大模型Service实现
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.net.URI;
/**
 * {@link LLMService} implementation for model service "A", reached over WebSocket.
 **/
@Slf4j
@Service
public class ALMService implements LLMService {

    /**
     * Endpoint placeholder. Java-WebSocket accepts only {@code ws://} and {@code wss://} URIs,
     * so the placeholder uses a WebSocket scheme rather than {@code http://}.
     */
    private static final URI ENDPOINT = URI.create("ws://xxxxxxxxxxxx");

    /** Maximum time to wait for the connection to open, in milliseconds. */
    private static final long CONNECTION_OPEN_TIMEOUT_MS = 500L;

    /** Maximum time to wait for the first response, in milliseconds. */
    private static final long FIRST_RESPONSE_TIMEOUT_MS = 1000L;

    /** Maximum time to wait for the complete response, in milliseconds. */
    private static final long READ_TIMEOUT_MS = 3 * 60 * 1000L;

    /**
     * Opens a fresh WebSocket connection, sends the payload and blocks until the exchange is
     * finished. Messages are handed to {@code callback} as they arrive.
     *
     * @param payload  the parameters used to request the model
     * @param callback receives each response message
     */
    @Override
    public void dialogue(String payload, LLMChatCallback callback) {
        LLMWebSocketClient webSocketClient = new LLMWebSocketClient(
                ENDPOINT, CONNECTION_OPEN_TIMEOUT_MS, FIRST_RESPONSE_TIMEOUT_MS, READ_TIMEOUT_MS);
        try {
            webSocketClient.connect();
            webSocketClient.send(payload, callback);
        } finally {
            // Always release the connection, including when send fails or times out.
            webSocketClient.close();
        }
    }
}
2.6 使用
@Resource
private ALMService aLMService;
public void void dialogue() {
// 构建请求负载
String payload = "this is payload";
// 可以扩展使用工厂及策略模式获取对应的模型对话service
aLMService.dialogue(payload, message -> {
// 处理业务接收返回的信息 比如通过grpc的流式响应 onNext(message) 同步返回
System.out.println(message);
});
}