由于在与大模型中对话中会产生大量的临时训练数据,并且客户端和移动端需要双向绑定接受消息,所以采用Redis进行缓存服务,使用websocket进行客户端与移动端的通讯。本周实现了使用redis进行缓存临时训练数据,使用websocket进行客户端与服务端的通信,并将最后训练结果与历史会话转存到MySQL中。
Redis是一个完全开源免费的高性能(NOSQL)的key-value数据库。它遵守BSD协议,使用ANSI C语言编写,并支持网络和持久化。Redis拥有极高的性能,每秒可以进行11万次的读取操作和8.1万次的写入操作。它支持丰富的数据类型,包括String、Hash、List、Set和Ordered Set,并且所有的操作都是原子性的。此外,Redis还提供了多种特性,如发布/订阅、通知、key过期等。Redis采用自己实现的分离器来实现高速的读写操作,效率非常高。Redis是一个简单、高效、分布式、基于内存的缓存工具,通过网络连接提供Key-Value式的缓存服务。
当客户端和移动端建立连接后,用户可与大模型进行对话,用户发送会话内容由后端接受。服务端处理客户端数据,拆分数据。向客户端发送ack。同时由大模型根据用户的dialog生成对应的corretion,根据序号验证房间中是否已存在该序号的用户会话。若不存在:放入redis中的房间,将会话正文(并结合本房间历史会话)发给大模型,等待大模型回应。收到回应后,根据redis房间中的历史会话,按照AI会话标记序号,放入房间,随后发送给客户端。加入了超时重传机制,当服务端向用户端发送内容后,计时器开始计时,若规定时间内没有收到ack,则重传。会话结束后,将Redis中房间存储拿出修整数据结构,统计训练数据(训练时间、分数),存入数据库,销毁redis房间。
流程如下:
redis增删改查实现:
package com.clankalliance.backbeta.utils;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class RedisUtils {
/**
* 数据缓存至redis
*
* @param key
* @param value
* @return
*/
public static <V> void add(String key, V value, StringRedisTemplate targetMap) {
try {
if(value != null)
targetMap.opsForValue().set(key, JSON.toJSONString(value));
} catch (Exception e) {
throw new RuntimeException("数据缓存至redis失败");
}
}
/**
* 从redis中获取缓存数据,转成对象
*
* @param key must not be {@literal null}.
* @param clazz 对象类型
* @return
*/
public static <V> V getObject(String key, StringRedisTemplate targetMap, Class<V> clazz) {
String value = get(key, targetMap);
V result = null;
if (value != null && !value.equals("")) {
result = JSONObject.parseObject(value, clazz);
}
return result;
}
/**
* 从redis中获取缓存数据,转成list
*
* @param key must not be {@literal null}.
* @param clazz 对象类型
* @return
*/
public static <V> List<V> getList(String key, StringRedisTemplate targetMap, Class<V> clazz) {
String value = get(key, targetMap);
List<V> result = Collections.emptyList();
if (value != null && !value.equals("")) {
result = JSONArray.parseArray(value, clazz);
}
return result;
}
/**
* 功能描述:Get the value of {@code key}.
*
* @param key must not be {@literal null}.
* @return java.lang.String
* @date 2021/9/19
**/
public static String get(String key, StringRedisTemplate targetMap) {
String value;
try {
value = targetMap.opsForValue().get(key);
} catch (Exception e) {
throw new RuntimeException("从redis缓存中获取缓存数据失败");
}
return value;
}
/**
* 删除key
*/
public static void delete(String key, StringRedisTemplate targetMap) {
targetMap.delete(key);
}
/**
* 批量删除key
*/
public static void delete(Collection<String> keys, StringRedisTemplate targetMap) {
targetMap.delete(keys);
}
/**
* 是否存在key
*/
public static Boolean hasKey(String key, StringRedisTemplate targetMap) {
return targetMap.hasKey(key);
}
/**
* 作为set使用 添加字符串
* @param setName
* @param targetMap
* @param value
*/
public static void setAdd(String setName, RedisTemplate<String,String> targetMap, String value){
targetMap.opsForSet().add(setName, value);
}
/**
* 返回set大小
* @param setName
* @param targetMap
* @return
*/
public static Long getSetSize(String setName, RedisTemplate<String,String> targetMap){
return targetMap.opsForSet().size(setName);
}
/**
* set中随机弹出
* @param setName
* @param targetMap
* @return
*/
public static String setPop(String setName, RedisTemplate<String,String> targetMap){
return targetMap.opsForSet().pop(setName);
}
/**
* set中删除指定元素
* @param setName
* @param targetMap
* @param value
*/
public static void setDel(String setName, RedisTemplate<String,String> targetMap, String value){
targetMap.opsForSet().remove(setName, value);
}
/**
* set是否存在指定元素
* @param setName
* @param targetMap
* @param value
* @return
*/
public static boolean setContains(String setName, RedisTemplate<String,String> targetMap, String value){
return Objects.equals(targetMap.opsForSet().isMember(setName, value), Boolean.TRUE);
}
}
websocket初始版本搭建:
@ -0,0 +1,149 @@
package com.clankalliance.backbeta.controller;
import com.clankalliance.backbeta.entity.User;
import com.clankalliance.backbeta.response.CommonResponse;
import com.clankalliance.backbeta.service.UserService;
import com.clankalliance.backbeta.utils.RedisUtils;
import com.clankalliance.backbeta.utils.TokenUtil;
import com.clankalliance.backbeta.utils.Websocket.SocketDomain;
import io.netty.util.internal.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.Map;
import java.util.StringJoiner;
import java.util.concurrent.ConcurrentHashMap;
@Component
@ServerEndpoint("/websocket/{token}")
public class WebSocketServer {
private static TokenUtil tokenUtil;
/**
* WebSocket为多对象的
* @serverEndpoint下的@Resource注解会失效
* 故采用set方法注入
*/
@Resource
public void setTokenUtil(TokenUtil tokenUtil){WebSocketServer.tokenUtil = tokenUtil;}
// Redis的一个引入范例
// /**
// * key: id
// * value: roomCode
// */
// private static StringRedisTemplate RedisTemplateIdRoomCode;
//
// @Resource
// public void setRedisTemplateIdRoomCode(StringRedisTemplate redisTemplateIdRoomCode){WebSocketServer.RedisTemplateIdRoomCode = redisTemplateIdRoomCode;}
private static UserService userService;
@Resource
public void setUserService(UserService userService){WebSocketServer.userService = userService;}
private static final Logger logger = LoggerFactory.getLogger(WebSocketServer.class);
//在线客户端数目
private static int onlineCount = 0;
//Map用于存储已连接的客户端信息(打算用redis改进)
private static ConcurrentHashMap<String, SocketDomain> websocketMap = new ConcurrentHashMap<>();
private Session session;
private String userId = "";
@OnOpen
public void onOpen(Session session, @PathParam("token") String token){
//TODO: 用户连接上Websocket客户端后,会调用该函数
CommonResponse response = tokenUtil.tokenCheck(token);
this.session = session;
if(!response.getSuccess()){
sendMessage("loginFail");
return;
}
String targetId = response.getMessage();
if(!websocketMap.containsKey(targetId)){
onlineCount ++;
}
this.userId = targetId;
SocketDomain socketDomain = new SocketDomain();
socketDomain.setSession(session);
socketDomain.setUri(session.getRequestURI().toString());
websocketMap.put(userId, socketDomain);
logger.info("id为" + userId + "的用户连接,当前人数为" + onlineCount);
}
@OnClose
public void onClose(){
if(websocketMap.containsKey(userId)){
websocketMap.remove(userId);
onlineCount --;
logger.info("id为" + userId + "的用户断开连接,当前人数为" + onlineCount);
}
//TODO: 用户断开连接Websocket客户端后,会调用该函数
}
@OnMessage
public void onMessage(String message, Session session){
//TODO: 用户向Websocket客户端发送消息,会调用该函数
if(!StringUtil.isNullOrEmpty(message)){
logger.info("收到id为" + userId + "的用户发来消息:" + message);
}
}
private void sendMessage(String obj){
synchronized (session){
this.session.getAsyncRemote().sendText(obj);
}
}
private void sendMessageTo(String userId,String obj){
SocketDomain socketDomain = websocketMap.get(userId);
try {
if(socketDomain !=null){
socketDomain.getSession().getAsyncRemote().sendText(obj);
}
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e.getMessage());
}
}
private void sendMessageToAllExpectSelf(String message, Session fromSession) {
for(Map.Entry<String, SocketDomain> client : websocketMap.entrySet()){
Session toSession = client.getValue().getSession();
if( !toSession.getId().equals(fromSession.getId())&&toSession.isOpen()){
toSession.getAsyncRemote().sendText(message);
logger.info("服务端发送消息给"+client.getKey()+":"+message);
}
}
}
private void sendMessageToAll(String message){
for(Map.Entry<String, SocketDomain> client : websocketMap.entrySet()){
Session toSeesion = client.getValue().getSession();
if(toSeesion.isOpen()){
toSeesion.getAsyncRemote().sendText(message);
logger.info("服务端发送消息给"+client.getKey()+":"+message);
}
}
}
//给外部调用的方法接口
public void sendAll(String Message){
sendMessageToAll(Message);
}
}