1.maven坐标
<!-- Spring Boot WebSocket starter. No <version> is declared here, so it is presumably managed by the Spring Boot parent/BOM; confirm in the POM. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2.创建处理器
/**
 * Text WebSocket handler that keeps at most one registered session per user.
 *
 * <p>Sessions are keyed by the {@code username} session attribute, falling back to the
 * container-assigned session id when that attribute is absent or empty. A new connection
 * under an existing key replaces the previous session and closes it.
 */
@Component
public class CustomWebSocketHandler extends TextWebSocketHandler {
    private static final Logger logger = LoggerFactory.getLogger(CustomWebSocketHandler.class);

    /**
     * Registered sessions, keyed by {@link #getSessionId(WebSocketSession)}. Kept public for
     * backward compatibility; prefer {@link #sendMessage(String, String)} and
     * {@link #getSessionIds()} over touching the map directly.
     */
    public static final ConcurrentHashMap<String, WebSocketSession> WEB_SOCKET_SESSION_MAP = new ConcurrentHashMap<>();

    /** Logs every incoming text frame; no reply is sent. */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        logger.info("接受到消息【{}】的消息:{}", session.getId(), message.getPayload());
    }

    /**
     * Registers the new session (closing any previous session under the same key) and sends
     * the key back to the client.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String sessionId = getSessionId(session);
        // put() swaps atomically and hands back the superseded session; a separate
        // containsKey/get/put sequence could race with a concurrent connect for the same user.
        WebSocketSession previous = WEB_SOCKET_SESSION_MAP.put(sessionId, session);
        if (previous != null && previous != session) {
            closeQuietly(previous, sessionId);
        }
        logger.info("与【{}】建立了连接", sessionId);
        sendMessage(sessionId, sessionId);
        // Log attribute names only: the values include the handshake token.
        logger.info("attributes:{}", session.getAttributes().keySet());
    }

    /**
     * Drops the closed session from the registry.
     *
     * <p>The session is already closed at this point, so it is not closed again. Removal is
     * conditional on the mapping still pointing at this session, because a newer connection for
     * the same user may already have replaced it.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        String sessionId = getSessionId(session);
        logger.info("连接对象【{}】断开连接,status:{}", sessionId, status.getCode());
        WEB_SOCKET_SESSION_MAP.remove(sessionId, session);
    }

    /** Logs the transport error, closes the session if still open and unregisters it. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        String sessionId = getSessionId(session);
        // Pass the throwable as the last argument so SLF4J records the stack trace.
        logger.warn("连接对象【{}】发生错误,exception:{}", session.getId(), exception.getMessage(), exception);
        if (session.isOpen()) {
            session.close();
        }
        WEB_SOCKET_SESSION_MAP.remove(sessionId, session);
    }

    /**
     * Returns the registry key for a session: the {@code username} attribute when it is a
     * non-empty String, otherwise the session's own id. The fallback avoids a null key, which
     * {@link ConcurrentHashMap} rejects.
     */
    private String getSessionId(WebSocketSession session) {
        Object username = session.getAttributes().get("username");
        if (username instanceof String && !((String) username).isEmpty()) {
            return (String) username;
        }
        return session.getId();
    }

    /** Closes a superseded session; a failure is logged rather than aborting the new connection. */
    private void closeQuietly(WebSocketSession stale, String sessionId) {
        if (!stale.isOpen()) {
            return;
        }
        try {
            stale.close();
        } catch (IOException e) {
            logger.warn("关闭旧连接【{}】失败", sessionId, e);
        }
    }

    /**
     * Sends a text frame to the session registered under {@code sessionId}.
     * If no open session is registered (or {@code sessionId} is null), a warning is logged instead.
     *
     * @param sessionId registry key of the target session
     * @param message   text payload
     * @throws IOException if the underlying send fails
     */
    public void sendMessage(String sessionId, String message) throws IOException {
        WebSocketSession webSocketSession = sessionId == null ? null : WEB_SOCKET_SESSION_MAP.get(sessionId);
        if (webSocketSession == null || !webSocketSession.isOpen()) {
            logger.warn("连接对象【{}】已关闭,无法送消息:{}", sessionId, message);
            return;
        }
        // WebSocketSession.sendMessage must not be called concurrently on the same session;
        // this method is public and reachable from any thread, so serialize sends per session.
        synchronized (webSocketSession) {
            webSocketSession.sendMessage(new TextMessage(message));
        }
        logger.info("sendMessage:向{}发送消息:{}", sessionId, message);
    }

    /** Serializes {@code data} with {@code JSON.toJSONString} and sends it as text. */
    public void sendMessage(String sessionId, Object data) throws IOException {
        sendMessage(sessionId, JSON.toJSONString(data));
    }

    /** Returns a snapshot of the currently registered session keys. */
    public List<String> getSessionIds() {
        return new ArrayList<>(WEB_SOCKET_SESSION_MAP.keySet());
    }
}
3.创建拦截器
/**
 * Handshake interceptor that copies the {@code token} and {@code username} query parameters
 * into the WebSocket session attributes. It runs in addition to the HTTP-session attribute
 * copying inherited from {@link HttpSessionHandshakeInterceptor}.
 */
@Component
public class CustomWebsocketInterceptor extends HttpSessionHandshakeInterceptor {
    private static final Logger logger = LoggerFactory.getLogger(CustomWebsocketInterceptor.class);

    /**
     * Enables HTTP-session creation once, at construction. The original set it on every
     * handshake, which mutated shared state of this singleton bean per request.
     */
    public CustomWebsocketInterceptor() {
        setCreateSession(true);
    }

    /**
     * Reads {@code token} and {@code username} from the servlet request and stores the non-null
     * ones as session attributes. It then delegates to the parent to copy HTTP-session attributes.
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // Assumes a servlet container (request is a ServletServerHttpRequest).
        ServletServerHttpRequest req = (ServletServerHttpRequest) request;
        String token = req.getServletRequest().getParameter("token");
        String username = req.getServletRequest().getParameter("username");
        // Never log the raw token; only whether one was supplied.
        logger.info("建立连接....token:{} username:{}", token == null ? null : "******", username);
        logger.info("attributes:{}", attributes);
        // NOTE(review): the token is stored but not validated here — confirm authentication happens elsewhere.
        // Skip nulls: the attributes are later copied into the session's attribute map, which is
        // presumably a ConcurrentHashMap that rejects null values (verify for your Spring version).
        if (token != null) {
            attributes.put("token", token);
        }
        if (username != null) {
            attributes.put("username", username);
        }
        return super.beforeHandshake(request, response, wsHandler, attributes);
    }

    /** Logs the handshake outcome; {@code exception} is non-null when the handshake failed. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        if (exception == null) {
            logger.info("连接成功....");
        } else {
            logger.warn("握手失败....", exception);
        }
        super.afterHandshake(request, response, wsHandler, exception);
    }
}
4.创建配置类
/**
 * Enables Spring's raw WebSocket support and maps {@link CustomWebSocketHandler} to
 * {@code /custom}, with {@link CustomWebsocketInterceptor} applied during the handshake.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** URL path the custom handler is mapped to. */
    private static final String CUSTOM_ENDPOINT = "/custom";

    /**
     * Origin pattern accepted during the handshake.
     * NOTE(review): "*" accepts connections initiated from any web page; consider restricting it.
     */
    private static final String ANY_ORIGIN = "*";

    @Resource
    private CustomWebsocketInterceptor customWebsocketInterceptor;

    @Resource
    private CustomWebSocketHandler customWebSocketHandler;

    /** Registers the handler together with its origin policy and handshake interceptor. */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(customWebSocketHandler, CUSTOM_ENDPOINT)
                .setAllowedOrigins(ANY_ORIGIN)
                .addInterceptors(customWebsocketInterceptor);
    }
}