WebSocket 服务端增加定时发送心跳机制
/**
 * WebSocket endpoint that keeps one registered connection per {@code uuid} path parameter and
 * takes part in an application-level ping/pong heartbeat.
 *
 * <p>Every open connection is registered in {@link #webSocketSet} (iterated by the heartbeat
 * sender) and in {@link #webSocketMap} (keyed by uuid). When a new connection arrives for a uuid
 * that is already registered, the previous connection is replaced and its session is closed.
 *
 * <p>{@link #heartbeatAttempts} counts pings that have not been answered yet; it is reset to zero
 * whenever the client sends the text {@code "pong"}.
 */
@ServerEndpoint(value = "/websocket/{uuid}")
@Component
public class DevMessageHandleController {
    private static final Logger logger = LoggerFactory.getLogger(DevMessageHandleController.class);

    /** All registered endpoint instances; read by the heartbeat scheduler. */
    public static CopyOnWriteArraySet<DevMessageHandleController> webSocketSet = new CopyOnWriteArraySet<>();

    /** Registered endpoint instances keyed by the {@code uuid} path parameter. */
    private static ConcurrentHashMap<String, DevMessageHandleController> webSocketMap = new ConcurrentHashMap<>();

    /** Container session of this connection; assigned in {@link #onOpen}. */
    private Session session;

    /** The {@code uuid} path parameter this connection was opened with; assigned in {@link #onOpen}. */
    private String uuid;

    /** Consecutive heartbeat pings sent to this connection without a "pong" reply. */
    private AtomicInteger heartbeatAttempts = new AtomicInteger(0);

    /**
     * Registers a newly opened connection, replacing (and closing) any connection already
     * registered under the same uuid.
     *
     * <p>The new connection is registered BEFORE the old session is closed, so a failure while
     * closing the old session can no longer prevent the new one from being registered. Because
     * the map slot is swapped atomically with {@code put}, the old instance's close callback
     * (which only removes its own mapping) cannot evict the new one.
     *
     * @param uuid    client identifier taken from the endpoint path
     * @param session the container session for this connection
     */
    @OnOpen
    public void onOpen(@PathParam("uuid") String uuid, Session session) {
        logger.info("uuid: {}, sessionId: {}", uuid, session.getId());
        try {
            this.session = session;
            this.uuid = uuid;
            heartbeatAttempts = new AtomicInteger(0);
            DevMessageHandleController previous = webSocketMap.put(uuid, this);
            webSocketSet.add(this);
            if (previous != null && previous != this) {
                webSocketSet.remove(previous);
                closeQuietly(previous.session);
            }
        } catch (Exception e) {
            logger.error("onOpen error:" + e.getMessage(), e);
        }
    }

    /**
     * Removes this connection from the registries when the container reports it closed.
     *
     * @param uuid    client identifier taken from the endpoint path
     * @param session the session being closed
     */
    @OnClose
    public void onClose(@PathParam("uuid") String uuid, Session session) {
        logger.info("会话关闭");
        unregister();
    }

    /**
     * Handles a text frame from the client. The literal text {@code "pong"} resets the
     * unanswered-heartbeat counter.
     *
     * @param message text received from the client
     * @param session the session the message arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        logger.info("Message from client: {}", message);
        if ("pong".equals(message)) {
            heartbeatAttempts.set(0);
            logger.info("Received pong from: {}", session.getId());
        }
    }

    /**
     * Logs a transport/endpoint error, closes the session and always unregisters this
     * connection, even when closing fails.
     *
     * @param session the session on which the error occurred
     * @param error   the reported error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("发生错误 session:" + session.getId() + ",error:" + error, error);
        try {
            session.close();
        } catch (IOException e) {
            logger.error("onError error:" + e.getMessage(), e);
        } finally {
            unregister();
        }
    }

    /**
     * Sends {@code msg} asynchronously if {@code session} is open. Otherwise closes it and
     * unregisters THIS endpoint instance (callers presumably pass {@code this.session} — verify).
     *
     * <p>NOTE(review): this and the heartbeat sender may both call
     * {@code getAsyncRemote().sendText} on the same session. Some containers reject a new async
     * send while a previous one is still in flight — confirm for the deployed container.
     *
     * @param session target session
     * @param msg     text to send
     */
    public void sendMessage(Session session, String msg) {
        logger.info("发送消息");
        if (session.isOpen()) {
            session.getAsyncRemote().sendText(msg);
            return;
        }
        try {
            session.close();
        } catch (IOException e) {
            logger.error("sendMessage close error:" + e.getMessage(), e);
        } finally {
            unregister();
        }
    }

    /**
     * Removes this instance from both registries. The map entry is removed only while it still
     * points at this instance, so a replacement connection for the same uuid is left intact.
     */
    private void unregister() {
        webSocketSet.remove(this);
        if (uuid != null) {
            // ConcurrentHashMap rejects null keys; uuid is null only if onOpen never ran.
            webSocketMap.remove(uuid, this);
        }
    }

    /** Closes {@code target}, logging rather than propagating an IOException; ignores null. */
    private static void closeQuietly(Session target) {
        if (target == null) {
            return;
        }
        try {
            target.close();
        } catch (IOException e) {
            logger.error("close previous session error:" + e.getMessage(), e);
        }
    }

    /** @return the shared set of registered endpoint instances */
    public static CopyOnWriteArraySet<DevMessageHandleController> getWebSocketSet() {
        return webSocketSet;
    }

    /** Replaces the shared set of registered endpoint instances. */
    public static void setWebSocketSet(CopyOnWriteArraySet<DevMessageHandleController> webSocketSet) {
        DevMessageHandleController.webSocketSet = webSocketSet;
    }

    /** @return the shared uuid-to-endpoint map */
    public static ConcurrentHashMap<String, DevMessageHandleController> getWebSocketMap() {
        return webSocketMap;
    }

    /** Replaces the shared uuid-to-endpoint map. */
    public static void setWebSocketMap(ConcurrentHashMap<String, DevMessageHandleController> webSocketMap) {
        DevMessageHandleController.webSocketMap = webSocketMap;
    }

    /** @return the container session of this connection */
    public Session getSession() {
        return session;
    }

    public void setSession(Session session) {
        this.session = session;
    }

    /** @return the uuid this connection was opened with */
    public String getUuid() {
        return uuid;
    }

    public void setUuid(String uuid) {
        this.uuid = uuid;
    }

    /** @return the counter of heartbeat pings not yet answered with "pong" */
    public AtomicInteger getHeartbeatAttempts() {
        return heartbeatAttempts;
    }

    public void setHeartbeatAttempts(AtomicInteger heartbeatAttempts) {
        this.heartbeatAttempts = heartbeatAttempts;
    }
}
每隔 10 秒向客户端发送一次心跳（ping）；若连续 3 次 ping 都未收到客户端回复的 pong，则关闭该会话。
/** Number of consecutive unanswered pings after which a session is closed instead of pinged. */
private static final int MAX_HEARTBEAT_ATTEMPTS = 3;

/**
 * Sends a text {@code "ping"} to every registered WebSocket endpoint; Spring re-runs this
 * 10 000 ms after the previous run finishes ({@code fixedDelay}).
 *
 * <p>Each ping increments the endpoint's unanswered-ping counter (the endpoint resets it when
 * the client answers {@code "pong"}). Once the counter reaches {@link #MAX_HEARTBEAT_ATTEMPTS}
 * the session is closed instead. Each endpoint is handled in its own try/catch, so a failure on
 * one session no longer stops heartbeats for the rest.
 */
@Scheduled(fixedDelay = 10000)
public void sendHeartBeat() {
    try {
        CopyOnWriteArraySet<DevMessageHandleController> webSocketSet = DevMessageHandleController.getWebSocketSet();
        logger.info("连接数量:" + webSocketSet.size());
        if (webSocketSet.isEmpty()) {
            return;
        }
        logger.info("定时发送心跳");
        // CopyOnWriteArraySet iterates a snapshot, so concurrent open/close cannot break the loop.
        for (DevMessageHandleController obj : webSocketSet) {
            try {
                sendHeartBeatTo(obj);
            } catch (Exception e) {
                logger.error("发送心跳 error:" + e.getMessage(), e);
            }
        }
    } catch (Exception e) {
        logger.error("发送心跳 error:" + e.getMessage(), e);
    }
}

/**
 * Pings one endpoint, or closes its session if it has already reached
 * {@link #MAX_HEARTBEAT_ATTEMPTS} unanswered pings.
 */
private void sendHeartBeatTo(DevMessageHandleController obj) {
    Session session = obj.getSession();
    logger.info("sessionId:" + session.getId() + " 心跳ping发送次数:" + obj.getHeartbeatAttempts().get());
    if (obj.getHeartbeatAttempts().get() >= MAX_HEARTBEAT_ATTEMPTS) {
        try {
            session.close();
        } catch (IOException e) {
            logger.error("session close error:" + e.getMessage(), e);
        }
        if (!session.isOpen()) {
            // Drop it directly so a stale entry (e.g. close() threw) is not revisited every tick;
            // the conditional map remove never evicts a newer connection for the same uuid.
            DevMessageHandleController.getWebSocketSet().remove(obj);
            String uuid = obj.getUuid();
            if (uuid != null) {
                DevMessageHandleController.getWebSocketMap().remove(uuid, obj);
            }
        }
    } else {
        obj.getHeartbeatAttempts().incrementAndGet();
        if (session.isOpen()) {
            session.getAsyncRemote().sendText("ping");
        }
    }
}