websocket实现有四种方法,这里展示相对简单的一种方法
这里使用原生注解的方法实现
说明需要注意的点:
- 引用的包都在 **javax.websocket **下。并不是 spring 提供的,而 jdk 自带的。
下面关于使用到的几个注解的说明:
- @ServerEndpoint :通过这个 spring boot 就可以知道你暴露出去的 ws 应用的路径,有点类似我们经常用的@RequestMapping。比如你的启动端口是 8080,而这个注解的值是 ws,那我们就可以通过 ws://127.0.0.1:8080/ws 来连接你的应用
- @OnOpen:当 websocket 建立连接成功后会触发这个注解修饰的方法,注意它有一个 Session 参数
- @OnClose: 当 websocket 建立的连接断开后会触发这个注解修饰的方法
- @OnMessage: 当客户端发送消息到服务端时,会触发这个注解修改的方法,如果需要做心跳检测可以在这里做。
- @OnError::当 websocket 建立连接时出现异常会触发这个注解修饰的方法
- 使用 session.getBasicRemote().sendText(*) 向客户端发送消息
如下是具体实现
- pom.xml 引入架包
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2.WebSocketConfig
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
/**
* 开启WebSocket
* WebScoket配置
* 通过这个配置 spring boot 才能去扫描后面的关于 websocket 的注解
*/
@Configuration
public class WebSocketConfig {
@Bean
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}
}
3.WebSocketServer
package com.dw.sprboosoc.service;
import com.alibaba.fastjson.JSON;
import com.dw.sprboosoc.constant.MessageEnum;
import com.dw.sprboosoc.dto.WebSocketMessageDto;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
// 添加Bean
@Component
@Slf4j
//访问路径
@ServerEndpoint(value = "/websocket/{sendUserId}")
public class WebSocketServer {
//静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
private static final AtomicInteger currentOnlineNumber = new AtomicInteger();
//concurrent包的线程安全Set,用来存放每个客户端对应的WebSocketServer对象
private static final ConcurrentHashMap<String, Session> sessionPool = new ConcurrentHashMap<>();
//设置为静态的 公用一个消息map ConcurrentMap为线程安全的map HashMap不安全
private static final ConcurrentMap<String, Map<String, List<WebSocketMessageDto>>> messageMap = new ConcurrentHashMap<>();
/*
*发送消息
* @param [session, message, userId]
* @return void
*/
public void sendMessage(WebSocketMessageDto webSocketMessageDto) throws IOException {
try {
switch (webSocketMessageDto.getMessageEnum()) {
// 广播消息
case ALL:
sessionPool.values().forEach(se -> {
try {
se
.getBasicRemote()
.sendText(webSocketMessageDto.toString());
} catch (IOException e) {
log.info("群发");
}
});
log.info("websocket: 广播消息:" + webSocketMessageDto);
// 离线
storeOfflineMessage(webSocketMessageDto);
break;
// 私发
case ONE:
//根据约定的字段,检测心跳
String message = webSocketMessageDto.getMessage();
if (message.equals("ping")) {
WebSocketMessageDto wd = new WebSocketMessageDto();
wd.setMessage("pong");
wd.setRecvUserId("");
wd.setMessageEnum(MessageEnum.ONE);
sessionPool.get(webSocketMessageDto.getSendUserId())
.getBasicRemote()
.sendText(JSON.toJSONString(wd));
} else {
if (judgeUserOnline(webSocketMessageDto.getRecvUserId())) {
// 在线
/*sessionPool.get(webSocketMessageDto.getRecvUserId())
.getBasicRemote()
.sendText(webSocketMessageDto.toString());*/
sessionPool.get(webSocketMessageDto.getRecvUserId())
.getBasicRemote()
.sendText(JSON.toJSONString(webSocketMessageDto));
} else {
// 离线
storeOfflineMessage(webSocketMessageDto);
}
log.info("websocket: 私发消息," + webSocketMessageDto);
break;
}
}
} catch (Exception exception) {
log.error("websocket: 发送消息发生了错误");
}
}
/*
*客户端收到消息
* @param [message]
* @return void
*/
@OnMessage
public void onMessage(String webSocketMessageDtoStr) throws IOException {
WebSocketMessageDto webSocketMessageDto = JSON.parseObject(webSocketMessageDtoStr, WebSocketMessageDto.class);
log.info("websocket:" + webSocketMessageDto.getRecvUserId() + "收到,来自:" + webSocketMessageDto.getSendUserId() + ",发送的消息:" + webSocketMessageDto.getMessage());
sendMessage(webSocketMessageDto);
}
/*
*判断用户是否在线
* @param [recvUserId]
* @return boolean
*/
public boolean judgeUserOnline(String recvUserId) {
boolean flag = !ObjectUtils.isEmpty(sessionPool.get(recvUserId));
String flagStr = flag ? "在线" : "离线";
log.info("websocket: " + recvUserId + ":" + flagStr);
return flag;
}
/*
*用户离线时把消息存储到内存
* @param [recvUserId]
* @return void
*/
public void storeOfflineMessage(WebSocketMessageDto webSocketMessageDto) {
//用户不在线时 第一次给他发消息
if (ObjectUtils.isEmpty(messageMap.get(webSocketMessageDto.getRecvUserId()))) {
Map<String, List<WebSocketMessageDto>> maps = new HashMap<>();
List<WebSocketMessageDto> list = new ArrayList<>();
list.add(webSocketMessageDto);
maps.put(webSocketMessageDto.getRecvUserId(), list);
messageMap.put(webSocketMessageDto.getRecvUserId(), maps);
} else {
//用户不在线时 再次发送消息
Map<String, List<WebSocketMessageDto>> listObject = messageMap.get(webSocketMessageDto.getRecvUserId());
List<WebSocketMessageDto> objects = new ArrayList<>();
if (!ObjectUtils.isEmpty(listObject.get(webSocketMessageDto.getRecvUserId()))) {//这个用户给收消息的这个用户发过消息
//此用户给该用户发送过离线消息(此用户给该用户发过的所有消息)
objects = listObject.get(webSocketMessageDto.getRecvUserId());
//加上这次发送的消息
objects.add(webSocketMessageDto);
//替换原来的map
listObject.put(webSocketMessageDto.getRecvUserId(), objects);
} else {//这个用户没给该用户发送过离线消息
objects.add(webSocketMessageDto);
listObject.put(webSocketMessageDto.getRecvUserId(), objects);
}
messageMap.put(webSocketMessageDto.getRecvUserId(), listObject);
}
}
/*
*成功建立连接后调用
* @param [session, userId]
* @return void
*/
@OnOpen
public void onOpen(Session session, @PathParam(value = "sendUserId") String sendUserId) throws IOException {
//成功建立连接后加入
sessionPool.put(sendUserId, session);
//当前在线数量+1
currentOnlineNumber.incrementAndGet();
log.info("websocket:" + sendUserId + "加入连接,当前在线用户" + currentOnlineNumber + "未读消息数:" + getMessageCount(sendUserId));
// 发送离线消息
sendOffLineMessage(sendUserId);
}
/*
* 用户上线时发送离线消息
* @param []
* @return void
*/
@SneakyThrows
public void sendOffLineMessage(String sendUserId) {
if (ObjectUtils.isEmpty(messageMap.get(sendUserId))) {
// 该用户没有离线消息
return;
}
// 当前登录用户有离线消息
//说明在用户没有登录的时候有人给用户发送消息
//该用户所有未收的消息
Map<String, List<WebSocketMessageDto>> lists = messageMap.get(sendUserId);
//对象用户发送的离线消息
List<WebSocketMessageDto> list = lists.get(sendUserId);
if (list != null) {
for (WebSocketMessageDto webSocketMessageDto : list) {
onMessage(JSON.toJSONString(webSocketMessageDto));
}
}
// 删除已发送的消息
removeHasBeenSentMessage(sendUserId, lists);
}
/*
*删除已发送的消息
* @param [sendUserId, map]
* @return void
*/
public void removeHasBeenSentMessage(String sendUserId, Map<String, List<WebSocketMessageDto>> map) {
// map中key(键)的迭代器对象
//用户接收完消息后删除 避免下次继续发送
Iterator iterator = map.keySet().iterator();
while (iterator.hasNext()) {// 循环取键值进行判断
String keys = (String) iterator.next();//键
if (sendUserId.equals(keys)) {
iterator.remove();
}
}
}
/*
*关闭连接时调用
* @param [userId]
* @return void
*/
@OnClose
public void onClose(@PathParam(value = "sendUserId") String sendUserId) {
sessionPool.remove(sendUserId);
currentOnlineNumber.decrementAndGet();
log.info("websocket:" + sendUserId + "断开连接,当前在线用户" + currentOnlineNumber);
}
/*
*发生错误时调用
* @param [session, throwable]
* @return void
*/
@OnError
public void onError(Throwable throwable) {
log.error("websocket: 发生了错误");
throwable.printStackTrace();
}
/**
* 获取该用户未读的消息数量
*/
public int getMessageCount(String recvUserId) {
//获取该用户所有未收的消息
Map<String, List<WebSocketMessageDto>> listMap = messageMap.get(recvUserId);
if (listMap != null) {
List<WebSocketMessageDto> list = listMap.get(recvUserId);
if (list != null) {
return listMap.get(recvUserId).size();
} else {
return 0;
}
} else {
return 0;
}
}
}
4.WebSocket 消息 DTO
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import java.io.Serializable;
@Getter
@Setter
@ToString
public class WebSocketMessageDto implements Serializable {
private static final long serialVersionUID = 4153093005674764992L;
/*发信人*/
private String sendUserId;
/*收信人*/
private String recvUserId;
/*消息内容*/
private String message;
/*消息类型*/
private MessageEnum messageEnum;
}
5.MessageEnum 消息体枚举
import lombok.Getter;
@Getter
public enum MessageEnum {
/**
* 私发
*/
ONE("one", "私发"),
/**
* 群发
*/
ALL("all", "群发"),
/**
* 其他
*/
OTHER("other", "其他");
/**
* 消息类型
*/
private final String messageType;
/**
* 消息类型描述
*/
private final String desc;
/**
* 消息枚举 构造器 私有化
*
* @param messageType 消息类型
* @param desc desc
*/
MessageEnum(String messageType, String desc) {
this.messageType = messageType;
this.desc = desc;
}
}
6.结果验证:建立连接方式可以在 http://www.jsons.cn/websocket/ 页面输入 ws://127.0.0.1:端口号/websocket/1234 建立连接、发送消息等