本周工作主要在对websocket进行测试,发现之前队友实现的websocket通信方法有些逻辑问题,对该部分问题进行修复并重新实现了部分机制:
1 在原版本代码中,关于训练房间与对话消息的存储,是使用原有RedisUtil工具类中的方法,尝试直接将数据库Entity存入Redis,而工具类中的存入方法是通过将对象转为JSON字符串后再存入,直接尝试存入Entity时会出现错误。
//Dialog.java
@Entity
@Table()
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Dialog {
@Id
private String id;
private Date time;
@Lob
private String content;
//切断循环引用
@ManyToOne
@JsonIgnoreProperties("trainingDataList")
private User sender;
private String correction;
private Double score;
}
//RedisUtil.java
//...
/**
* 数据缓存至redis
*
* @param key
* @param value
* @return
*/
public static <V> void add(String key, V value, StringRedisTemplate targetMap) {
try {
if(value != null){
if(hasKey(key, targetMap)){
delete(key, targetMap);
}
String temp = JSON.toJSONString(value);
targetMap.opsForValue().set(key, temp);
}
} catch (Exception e) {
throw new RuntimeException("数据缓存至redis失败");
}
}
//...
解决方法: 另外组织了一个不涉及数据库联表关系的类,通过该类来组织训练数据,并在最终训练结束将数据存入MySQL时对数据结构再进行转换,重新映射关系,并存入。
//DialogDataBody.java
@Data
@NoArgsConstructor
@AllArgsConstructor
public class DialogDataBody {
private String id;
private Date time;
private Long senderId;
private String content;
private String correction;
private Double score;
public Dialog toDialog(User sender){
return new Dialog(id, time, content, sender, correction, score);
}
}
2 超时重传无法工作: 原版本代码对OnMessage事件处理逻辑有问题,尝试在收到消息并处理的这一次OnMessage中处理下一次收到消息时的逻辑。故重写该部分逻辑。
解决方法: 每一次需要等待确认时建立一个线程,按照固定间隔时间发送消息。将这个线程用哈希表存储起来,键为消息id,值为线程(的引用),当收到确认消息时,按照id来终止线程执行。
//WebSocketServer.java
//...
private Hashtable<Long, AckWaitingThread> waitingMissions;
@OnOpen
public void onOpen(Session session, @PathParam("token") String token){
waitingMissions = new Hashtable<>();
//用户连接上Websocket客户端后,会调用该函数
CommonResponse response = tokenUtil.tokenCheck(token);
this.session = session;
if(!response.getSuccess()){
sendMessage("loginFail");
return;
}
String targetId = response.getMessage(); //message里存储用户Id
if(!websocketMap.containsKey(targetId)){
onlineCount ++;
}
this.userId = Long.parseLong(targetId);
Optional <User> uop = userRepository.findUserById(userId);
currentUser = uop.get();
SocketDomain socketDomain = new SocketDomain();
socketDomain.setSession(session);
socketDomain.setUri(session.getRequestURI().toString());
websocketMap.put(String.valueOf(userId), socketDomain);
isAckSay = true;
logger.info("id为" + userId + "的用户连接,当前人数为" + onlineCount);
//startTraining(targetId); //
//两个序号的初始化
sayPackageIdExcepted = 1;
aiCurrentPackageId = 1;
RedisUtils.add(String.valueOf(currentUser.getId()),new ArrayList<>(),redisTemplateUserRoom);
//5.11 新增: AI开场白(固定)
String openingMessage = "say#" + aiCurrentPackageId + "#" + "Hello, I'm your oral English speaking training assistant. What can I help you?";
sendMessageWithResend(openingMessage);
redisStor(AI_USER.getId(), openingMessage);
aiCurrentPackageId ++;
}
@OnClose
public void onClose(){
if(websocketMap.containsKey(userId)){
websocketMap.remove(userId);
onlineCount --;
logger.info("id为" + userId + "的用户断开连接,当前人数为" + onlineCount);
}
for(Map.Entry<Long, AckWaitingThread> t: waitingMissions.entrySet()){
t.getValue().interrupt();
}
waitingMissions.clear();
//断开连接后,会把数据存到数据库
List<DialogDataBody> dialogsRaw = RedisUtils.getList(String.valueOf(userId),redisTemplateUserRoom,DialogDataBody.class);
List<Dialog> dialogs = new ArrayList<>();
for(DialogDataBody d: dialogsRaw){
dialogs.add(d.toDialog(currentUser));
}
//5.11: 空的训练不需要存
if(dialogs.size() == 0)
return;
//将dialog存入数据库
dialogRepository.saveAll(dialogs);
TrainingData trainingData = new TrainingData();
//5.11: 训练数据初始化需指定id
trainingData.setId("" + snowFlake.nextId());
trainingData.setUser(userRepository.findUserById(userId).get());
Date currentTimeUser = new Date();
trainingData.setTime(currentTimeUser);
trainingData.setDialogs(dialogs);
//调用接口,获取score
//将trainingData存入数据库
trainingDataRepository.save(trainingData);
redisTemplateUserRoom.delete(String.valueOf(userId));
}
@OnMessage
public void onMessage(String message, Session session){
//ack#say#{AI消息序号},用户向Websocket客户端发送消息,会调用该函数
String[] request;
request = message.split("#");
//后续存储在MySQL里的时候,只需要判断dialog的sender是否为AI就可以吧
if(request[0].equals("ack")){
if(request[1].equals("say")){
//格式ack#say#{AI消息序号} 发过去的回应被收到了
isAckSay = true;//接收到了ack才能继续
try{
Long ackPackageId = Long.parseLong(request[2]);
waitingMissions.get(ackPackageId).interrupt();
waitingMissions.remove(ackPackageId);
}catch (Exception ignored){}
}else{
System.out.println("前端发送错误: " + message);
throw new RuntimeException();
}
}else if(request[0].equals("say")&&isAckSay&&isAckCorr){
//格式say#{用户消息序号}#{用户回应}#{用户回应更正}#{评分},用户传来文本,需要调用大模型
/*2024.4.28 增加一层对用户消息序号的验证
*
* */
//将消息确认放最前面,避免后面的say与corr导致ack超时引起客户端重传
//收到用户的say后向用户回确认
String content = request[2];
String correction = request[3];
Double score = 0.0;
try{
score = Double.parseDouble(request[4]);
}catch (Exception ignored){};
String ackMessage = "ack#say#" + request[1];
sendMessageTo(String.valueOf(userId),ackMessage);
if(Long.parseLong(request[1]) >= sayPackageIdExcepted){
//TODO: 考虑序号小的包比序号大的包来的晚(但这种情况似乎不太可能)
sayPackageIdExcepted ++;
String messageToUser;
//存储大模型纠错作为一个dialog
isAckCorr = false;
startTimeCorr = System.currentTimeMillis();
//先把用户的文本存一个dialog
redisStor(userId,content,correction, score);
/*
* 2024.4.28
* 接入大模型 纠错暂不接入,中期检查后再更新
* 针对你Redis的数据结构,对AI接口也做了修改
* aiCurrentPackageId代表当前发送的对话的序号,作为标识
* 避免重传可能引起的重复问题
* */
List<DialogDataBody> dialogs = RedisUtils.getList(String.valueOf(userId),redisTemplateUserRoom,DialogDataBody.class);
String contentAI = aiService.invokeModel(currentUser, dialogs);
redisStor(AI_USER.getId(),contentAI);
//向用户发送AI回复
messageToUser = "say#" + aiCurrentPackageId + "#" + contentAI;
sendMessageWithResend(messageToUser);
aiCurrentPackageId ++;
/*
* 修改结束
* */
isAckSay = false;
}
}
if(!StringUtil.isNullOrEmpty(message)){
logger.info("收到id为" + userId + "的用户发来消息:" + message);
}
}
private void sendMessageWithResend(String content){
AckWaitingThread ackWaitingThread = new AckWaitingThread(this, content);
waitingMissions.put(aiCurrentPackageId, ackWaitingThread);
ackWaitingThread.start();
}
//...
//AckWaitingThread.java
public class AckWaitingThread extends Thread{
WebSocketServer webSocketServer;
String sendContent;
AckWaitingThread(WebSocketServer webSocketServer, String sendContent){
this.webSocketServer = webSocketServer;
this.sendContent = sendContent;
}
public void run(){
try{
while (true){
webSocketServer.sendMessage(sendContent);
sleep(webSocketServer.ACK_TIMEOUT);
}
}catch (InterruptedException ignored){}
}
}
3 细节问题: 在向mySQL中存储时未分配id、redisUtil中方法调用错误。