概要
最近写的一个小项目,涉及到AI问答的功能,所以接入了科大讯飞的Api,刚开始有个问题是:使用http请求,一次性完整的获得一个答案,可能得需要30多秒,这样用户体验十分不友好,所以后面决定使用流式数据实现这一部分功能。
用到的技术:SpringBoot、WebSocket等。
SpringBoot中如何使用WebSocket?
1)引入依赖
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
<version>3.1.0</version>
</dependency>
2) 配置类
把serverEndpointExporter 交给Spring管理
@Configuration
public class WebSocketConfig {
@Bean
public ServerEndpointExporter serverEndpointExporter()
{
return new ServerEndpointExporter();
}
}
3) 创建一个WebSocker服务类
下面是一个基本的服务类
@Component
@Slf4j
@ServerEndpoint("/ws/{userid}")
//这里类似于Controller的接口,不同于Controller接口,这里是ws://....进行访问的。
public class WebSocketServer {
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam("userid") Integer userid) {
log.info("{} 与服务器进行连接.",userid);
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(String message, @PathParam("userid") String userid) {
log.info("用户:{},发送信息:{}",userid,message);
}
/**
* 连接关闭调用的方法
*
* @param userid
*/
@OnClose
public void onClose(@PathParam("userid") Integer userid) {
log.info("{} :关闭连接" , userid);
}
/**
* 给用户发送信息
*
* @param message
*/
public void sendAIResultToUser(Integer userid,String message) {
}
}
4) 测试连接
WebSocket在线测试工具 (wstool.js.org)
实现AI问答的流式输出
接入星火API,使用的是一个github上写好的SDK
<!--讯飞SDK-->
<dependency>
<groupId>io.github.briqt</groupId>
<artifactId>xunfei-spark4j</artifactId>
<version>1.2.0</version>
</dependency>
这个SDK提供了两种模式:一种是基于http的,一种是socket的,如果实现流式输出必然得使用后者。那么我想到的思路大概是这样的:前端打开AI问答的页面,在此时就和服务器进行socket连接,后端通过唯一的userId作为key,session作为value,把连接信息放到Map中。
当用户写好问题时,进行一次http请求,也就是相当于把问题和用户信息给后端(其实这里可以直接通过socket实现对话的方式进行AI问答,由于我需要对用户使用AI情况进行统计、管理,所以加了层http请求), 通过请求拿到userId,进而拿到对应的session,这时用户一定是和服务器有socket连接的,后端服务器与星火的服务器进行socket连接,星火的服务器给后端服务器响应内容,后端服务器就把相应的内容响应给前端,前端进行展示即可;用户可以继续使用,也可以退出此功能,退出时,就会把对应的session删除。
WebSocketServer 代码
前端首先回到这里与后端服务器的进行socket连接。
@Component
@Slf4j
@ServerEndpoint("/ws/{userid}")
public class WebSocketServer {
//存放会话对象
private static Map<Integer, Session> sessionMap = new HashMap();
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam("userid") Integer userid) {
log.info("{} 与服务器进行连接.",userid);
sessionMap.put(userid , session);
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(String message, @PathParam("userid") String userid) {
log.info("用户:{},发送信息:{}",userid,message);
}
/**
* 连接关闭调用的方法
*
* @param userid
*/
@OnClose
public void onClose(@PathParam("userid") Integer userid) {
log.info("{} :关闭连接" , userid);
sessionMap.remove(userid);
}
/**
* 给用户发送信息
*
* @param message
*/
public void sendAIResultToUser(Integer userid,String message) {
//获得对应的session
Session session = sessionMap.get(userid);
try {
//服务器向客户端发送消息
session.getBasicRemote().sendText(message);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* @description: 通过userId获得一个session
* @param: [java.lang.Integer]
* @return: javax.websocket.Session
*/
public Session getSessionByUserId(Integer userid){
return sessionMap.get(userid);
}
}
Controller层
调用/ai接口
@Resource
private WebSocketServer webSocketServer;
@PostMapping("/ai")
@ApiOperation("使用ai问答")
public ResultJson<String> useAI(@RequestBody AskContent askContent){
log.info("使用AI问答");
try{
// TODO:进行AI相关的统计、管理操作
//调用service层的方法
userService.sendMessageToXingHuo(askContent.getQuestion(),webSocketServer.getSessionByUserId(askContent.getUserid()));
return ResultJson.success(null);
}catch (Exception e){
redisUtil.incrby("ai:error",1);
System.err.println(e.getLocalizedMessage() + e.getMessage());
return ResultJson.error("AI机器人出了点错误,请稍后再试 ~");
}
}
Service层
@Override
public void sendMessageToXingHuo(String question, Session session) {
/*下面都是使用的SDK*/
List<SparkMessage> messages = new ArrayList<>(); //消息列表
//MessageConstant.PRECONDITION 是我定义的一个常量.
messages.add(SparkMessage.systemContent(MessageConstant.PRECONDITION)); //预设问题
messages.add(SparkMessage.userContent(question)); //设置问题
//发送信息
SparkRequest sparkRequest = SparkRequest.builder()
.messages(messages)
.maxTokens(1024) //回答的最大token
.temperature(0.5) //结果随机性
.apiVersion(SparkApiVersion.V3_5) //版本情况
.build(); //构建
//重新设置一个session(返回客户端)
sparkConsoleListener.setSession(session);
//封装聊天信息
sparkClient.chatStream(sparkRequest,sparkConsoleListener);
}
SparkConsoleListener
如果想要实现AI的流式输出,必须自己写一个类,实现SparkBaseListener,用于监听星火服务器给后端的响应,同时在这里后端把响应内容返回给前端。
public class SparkConsoleListener extends SparkBaseListener {
private Session session = null; //请求ai的会话
public void setSession(Session session){
this.session =session; //设置session
}
//固定的代码
public ObjectMapper objectMapper = new ObjectMapper();
{
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
}
@Override
public void onMessage(String content, SparkResponseUsage usage, Integer status, SparkRequest sparkRequest, SparkResponse sparkResponse, WebSocket webSocket) {
if (0 == status) {
List<SparkMessage> messages = sparkRequest.getPayload().getMessage().getText();
try {
System.out.println("提问:" + objectMapper.writeValueAsString(messages));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
try {
session.getBasicRemote().sendText(content); //将socket的信息返回给前端
} catch(IOException e){
System.err.println(e.getMessage());
}
if (2 == status) {
SparkTextUsage textUsage = usage.getText();
System.out.println("\n回答结束;提问tokens:" + textUsage.getPromptTokens()
+ ",回答tokens:" + textUsage.getCompletionTokens()
+ ",总消耗tokens:" + textUsage.getTotalTokens());
try {
//结束的时候发个 |
session.getBasicRemote().sendText("|");
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}
小结
网上有许多简单、高效的实现方法,这里只是我个人对该问题,十分局限的见解,如有问题欢迎指正。