后端 WebSocket 代码
使用的 Spring Boot 版本为 2.5.3。
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;

import java.nio.charset.StandardCharsets;
import java.util.Map;
/**
 * Handshake interceptor that requires a {@code token} query parameter on the WebSocket upgrade
 * request and copies it into the WebSocket session attributes under the key {@code "token"}.
 *
 * <p>Only presence is checked here; any real validation of the token would go in
 * {@link #beforeHandshake}.
 */
public class CustomHandshakeInterceptor implements HandshakeInterceptor {

    /**
     * Rejects the handshake with 401 Unauthorized when the token is missing, blank or not
     * decodable; otherwise stores the decoded token in {@code attributes} and allows the handshake.
     *
     * @param request    the upgrade request
     * @param response   the response, which receives the 401 status on rejection
     * @param wsHandler  the target handler (unused)
     * @param attributes session attributes that receive the {@code "token"} entry
     * @return {@code true} to proceed with the handshake, {@code false} to reject it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                   WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        String token = extractTokenFromRequest(request);
        if (StringUtils.isBlank(token)) {
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            response.getBody().flush(); // flush so the 401 status is sent to the client
            return false; // reject the handshake
        }
        attributes.put("token", token);
        return true;
    }

    /**
     * Reads the {@code token} query parameter from the request URI and percent-decodes it.
     *
     * @param request the upgrade request
     * @return the decoded token, or {@code null} if the parameter is absent or malformed
     */
    private String extractTokenFromRequest(ServerHttpRequest request) {
        // The builder is populated from URI#getRawQuery(), so getQueryParams() returns values that
        // are still percent-encoded (e.g. "a%2Bb"). Decode explicitly so tokens sent via
        // encodeURIComponent on the client arrive unchanged.
        String rawToken = UriComponentsBuilder.fromUri(request.getURI())
                .build()
                .getQueryParams()
                .getFirst("token");
        if (rawToken == null) {
            return null;
        }
        try {
            return UriUtils.decode(rawToken, StandardCharsets.UTF_8);
        } catch (IllegalArgumentException e) {
            // Invalid escape sequence such as "%zz": treat it like a missing token so the caller
            // answers 401 instead of failing the handshake with a server error.
            return null;
        }
    }

    /** No post-handshake work is needed. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                               WebSocketHandler wsHandler, Exception exception) {
        // intentionally empty
    }
}
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
 * Enables Spring WebSocket support and registers the {@code /websocket} endpoint.
 *
 * <p>The endpoint is served by the Spring-managed {@link WebSocketTest} bean, with
 * {@link CustomHandshakeInterceptor} guarding the handshake.
 */
@Slf4j
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /**
     * The {@code @Component} handler bean. Registering this instance, rather than a
     * {@code new WebSocketTest()}, means the endpoint and the bean whose {@code @Scheduled} methods
     * Spring runs are the same object.
     */
    private final WebSocketTest webSocketTest;

    public WebSocketConfig(WebSocketTest webSocketTest) {
        this.webSocketTest = webSocketTest;
    }

    /**
     * Registers the handler on {@code /websocket}, accepts any origin and installs the token
     * handshake interceptor.
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(websocketHandlerBean(), "/websocket")
                .setAllowedOrigins("*") // NOTE(review): "*" accepts every origin; restrict in production
                .addInterceptors(new CustomHandshakeInterceptor());
    }

    /**
     * Returns the handler used for the {@code /websocket} endpoint.
     *
     * @return the Spring-managed {@link WebSocketTest} bean
     */
    public WebSocketHandler websocketHandlerBean() {
        return webSocketTest;
    }
}
package cn.js.selfstudy.config.websocket;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Map;
import java.util.concurrent.*;
@Component
@EnableScheduling
@Slf4j
public class WebSocketTest implements WebSocketHandler {
private static final int MAX_CONNECTIONS = 30;//session 最大连接数
private static final long MAX_SESSION_TIMEOUT = 1000 * 60 * 2;//session 最大连接时长(单位:毫秒)
private static final DateTimeFormatter dateFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss:SSS");
private static final BlockingQueue<WebSocketSession> sessionList = new LinkedBlockingQueue<>(MAX_CONNECTIONS);//记录session
private static final ConcurrentHashMap<String, String> sessionTokenMap = new ConcurrentHashMap<>();// 记录session的token
private static final ConcurrentHashMap<String, Long> sessionStartTimeMap = new ConcurrentHashMap<>();// 记录session开始时间
@Override
public void afterConnectionEstablished(WebSocketSession session) {
String token = (String) session.getAttributes().get("token");
if (StringUtils.isBlank(token)) {//可根据需要添加token的验证
closeSession(session, "未传token,拒绝连接");// 未传token,拒绝连接
return;
}
if (!sessionList.offer(session)) {
closeSession(session, "队列已满,拒绝连接");// 队列已满,拒绝连接
return;
}
sessionTokenMap.put(session.getId(), token);
sessionStartTimeMap.put(session.getId(), System.currentTimeMillis());
log.info("[WebSocket] 新连接建立 sessionId:{},token:{},当前连接数:{}", session.getId(), token, sessionList.size());
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) {
log.info("[WebSocket] 收到来自客户端的消息 sessionId:{} 消息:{}", session.getId(), message.getPayload());
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) {
log.error("[WebSocket] 传输错误 sessionId:{}", session.getId(), exception);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) {
long duration = getSessionDuration(session.getId());
sessionList.remove(session);// 清除session记录
sessionTokenMap.remove(session.getId());// 清除token记录
sessionStartTimeMap.remove(session.getId()); // 清除开始时间记录
log.info("[WebSocket] 连接关闭 sessionId:{},会话持续时间:{} 秒, 当前连接数:{}", session.getId(), duration, sessionList.size());
}
@Override
public boolean supportsPartialMessages() {
return false;
}
@Scheduled(cron = "*/30 * * * * *")
private void sendScheduledMessages() {
log.info("[WebSocket] 准备发送消息 sessionsSize:{}", sessionList.size());
String message = "This is a scheduled message: " + LocalDateTime.now().format(dateFormat);
for (WebSocketSession webSocketSession : sessionList) {
if (webSocketSession.isOpen()) {
try {
log.info("[WebSocket] sessionId:{} sessionTokenMap:{}", webSocketSession.getId(), sessionTokenMap.size());
webSocketSession.sendMessage(new TextMessage(message));
} catch (IOException e) {
log.error("[WebSocket] 发送消息失败", e);
}
}
}
}
@Scheduled(cron = "*/10 * * * * *")
private void checkSessionTimeouts() {
log.info("[WebSocket] session超时检查 sessionsSize:{}", sessionList.size());
if (CollectionUtils.isEmpty(sessionList)) {
return;
}
long currentTime = System.currentTimeMillis();
for (Map.Entry<String, Long> entry : sessionStartTimeMap.entrySet()) {
if (entry.getValue() + MAX_SESSION_TIMEOUT < currentTime) {
// 找到超时的会话,可以关闭它
WebSocketSession session = sessionList.stream()
.filter(s -> s.getId().equals(entry.getKey()))
.findFirst()
.orElse(null);
if (session != null && session.isOpen()) {
closeSession(session, "超过连接最大时长,关闭连接");
}
}
}
}
//关闭连接
private void closeSession(WebSocketSession session, String reason) {
try {
session.close();
log.info("[WebSocket] 关闭连接 sessionId :{},reason:{}", session.getId(), reason);
} catch (IOException e) {
log.error("[WebSocket] 关闭连接异常,e:{}", e);
}
}
//获取session连接时长(秒)
private long getSessionDuration(String sessionId) {
Long startTime = sessionStartTimeMap.get(sessionId);
if (startTime == null) {
return -1; // 如果没有找到开始时间,返回-1或其他错误码
}
return (System.currentTimeMillis() - startTime) / 1000;
}
}
测试页面:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>WebSocket Test</title>
<script>
  // Global handle so the socket can be inspected from the browser console.
  var ws;

  // Once the DOM is ready, open the socket and attach logging handlers; received
  // messages are appended to the <pre id="messages"> element.
  document.addEventListener('DOMContentLoaded', function () {
    var endpoint = 'ws://localhost:9998/websocket?token=' + encodeURIComponent('test001'); // replace 'test001' with your own token

    function onOpen() {
      console.log('WebSocket 连接已打开。');
    }

    function onMessage(event) {
      console.log('收到消息:', event.data);
      var output = document.getElementById('messages');
      output.textContent = output.textContent + event.data + '\n';
    }

    function onError(error) {
      console.error('WebSocket 出现错误:', error);
    }

    function onClose() {
      console.log('WebSocket 连接已关闭。');
    }

    ws = new WebSocket(endpoint);
    ws.onopen = onOpen;
    ws.onmessage = onMessage;
    ws.onerror = onError;
    ws.onclose = onClose;
  });
</script>
</head>
<body>
<h5>WebSocket Messages</h5>
<pre id="messages"></pre>
</body>
</html>
以上代码基于 Spring Boot 2.5.3 实现了一个 WebSocket 服务,包括后端的依赖配置、握手拦截器、WebSocket 处理器,以及前端测试页面。
- 依赖配置:在 `pom.xml` 中添加 `spring-boot-starter-websocket` 依赖,用于支持 WebSocket 功能。
- `CustomHandshakeInterceptor`:自定义握手拦截器,在 WebSocket 握手过程中从请求参数中读取 token。如果 token 为空,则返回 401 并拒绝握手;否则将 token 存入会话属性。目前只检查 token 是否为空,如需校验 token 的有效性,可在此处扩展。
- `WebSocketConfig`:配置类,使用 `@EnableWebSocket` 注解启用 WebSocket 支持,注册 `/websocket` 端点,并为其指定自定义的 `HandshakeInterceptor`。
- `WebSocketTest`:实现了 `WebSocketHandler` 接口的类,用于处理 WebSocket 连接的生命周期事件,如连接建立、消息接收、连接关闭等。它维护一个最多容纳 30 个会话的集合,并包含两个定时任务:每 30 秒向所有客户端推送一条消息;每 10 秒检查一次会话,关闭连接时长超过 2 分钟的会话。
- 测试页面:一个简单的 HTML 页面,使用 JavaScript 创建携带 token 参数的 WebSocket 连接,并处理打开、消息、错误和关闭事件。