本篇文章实现的功能是:后端程序通过WebSocket向指定用户推送消息。前端用户登录后,在建立WebSocket连接(握手)时通过请求参数clientID把用户ID传给后台;后台将用户ID与对应的WebSocketSession保存起来(同一用户可以同时拥有多个连接),供推送消息时使用。
一、在pom.xml文件中加入Maven依赖
<!-- Spring WebSocket module: provides the org.springframework.web.socket.* types
     (@EnableWebSocket, WebSocketConfigurer, HandshakeInterceptor, TextWebSocketHandler, ...)
     used by the classes below.
     NOTE(review): keep all three Spring artifacts on the same version. -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-websocket</artifactId>
<version>4.3.9.RELEASE</version>
</dependency>
<!-- Spring MVC: provides WebMvcConfigurerAdapter, which WebSocketConfig extends. -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webmvc</artifactId>
<version>4.3.9.RELEASE</version>
</dependency>
<!-- Spring context: provides @Configuration and @Component used below.
     NOTE(review): javax.servlet.http.HttpServletRequest (used by the handshake interceptor) is not
     declared here; presumably supplied by the servlet container / a provided-scope dependency - confirm. -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
<version>4.3.9.RELEASE</version>
</dependency>
二、WebSocket配置类
package framework.websocket;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
 * Registers the application's WebSocket handler on a native WebSocket endpoint and on a SockJS
 * fallback endpoint. Both endpoints run the client-ID handshake interceptor.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig extends WebMvcConfigurerAdapter implements WebSocketConfigurer {

    /** Endpoint for clients with native WebSocket support. */
    private static final String WEB_SOCKET_PATH = "/webSocketService";

    /** SockJS fallback endpoint for clients without native WebSocket support. */
    private static final String SOCKJS_PATH = WEB_SOCKET_PATH + "/sockjs";

    /**
     * Application handler shared by both endpoints. This resolves to the same-package
     * framework.websocket.WebSocketHandler, not Spring's interface of the same simple name,
     * which is not imported here.
     */
    @Autowired
    private WebSocketHandler webSocketHandler;

    /**
     * Registers the native WebSocket endpoint ({@value #WEB_SOCKET_PATH}) and the SockJS endpoint
     * ({@value #SOCKJS_PATH}).
     *
     * @param registry Spring's handler registry
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // The interceptor holds no per-request state, so one instance serves both registrations.
        WebsocketHandshakeInterceptor handshakeInterceptor = new WebsocketHandshakeInterceptor();

        // NOTE(review): "*" accepts handshakes from any origin, which exposes these endpoints to
        // cross-site WebSocket hijacking; restrict it to the front-end origin(s) in production.

        // Native WebSocket endpoint.
        registry.addHandler(webSocketHandler, WEB_SOCKET_PATH)
                .addInterceptors(handshakeInterceptor)
                .setAllowedOrigins("*");

        // SockJS fallback endpoint.
        registry.addHandler(webSocketHandler, SOCKJS_PATH)
                .addInterceptors(handshakeInterceptor)
                .setAllowedOrigins("*")
                .withSockJS();
    }
}
三、WebSocket握手拦截器(在握手前后拦截请求和响应,并把请求中的属性传递给WebSocket会话)
package framework.websocket;
import java.util.Map;
import java.util.logging.Logger;
import javax.servlet.http.HttpServletRequest;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
public class WebsocketHandshakeInterceptor implements HandshakeInterceptor {
/**
* 获取日志对象(全类名)
*/
private static final Logger logger = Logger.getLogger(WebsocketHandshakeInterceptor.class.getName());
/**
* 在处理握手之前调用,这里是把获取的请求数据绑定到session的map对象中(attributes)
*/
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
// 将ServerHttpRequest转成ServletServerHttpRequest
ServletServerHttpRequest servletServerHttpRequest = (ServletServerHttpRequest) request;
// 获取请求参数,先获取HttpServletRequest对象才能获取请求参数
HttpServletRequest httpServletrequest = servletServerHttpRequest.getServletRequest();
logger.info(String.format("Websocket Handshake Interceptor, sessionID:%s",
httpServletrequest.getSession().getId()));
// 获取请求的数据
String clientID = httpServletrequest.getParameter(WebSocketConstant.CLIENT_ID);
if (null == clientID || "".equals(clientID)) {
return false;
}
// 把获取的请求数据绑定到session的map对象中(attributes)
attributes.put(WebSocketConstant.CLIENT_ID, clientID);
return true;
}
/*
* 握手完成后调用。响应状态和头指示握手的结果,即握手是否成功。
*/
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) {
// TODO Auto-generated method stub
}
}
四、WebSocket处理类
package framework.websocket;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
/**
 * Tracks the open WebSocket sessions of each client ID and pushes text messages to them.
 *
 * <p>The client ID comes from the session attribute {@code WebSocketConstant.CLIENT_ID}.
 * {@code WebsocketHandshakeInterceptor} fills that attribute from the handshake parameter. One
 * client ID may own several sessions at once (for example several browser tabs), so each ID maps
 * to a list.
 *
 * <p>Thread-safety: {@link #sendMessage(String, String)} is a public static entry point that other
 * application threads may call while the container delivers connection callbacks. For that reason:
 * <ul>
 *   <li>the registry is a {@link ConcurrentHashMap} updated only through its atomic
 *       {@code compute}/{@code computeIfPresent};</li>
 *   <li>each value is a {@link CopyOnWriteArrayList}, so iteration never sees a concurrent
 *       modification;</li>
 *   <li>every send to a session is serialized on that session.</li>
 * </ul>
 *
 * <p>NOTE(review): this class shares its simple name with Spring's
 * {@code org.springframework.web.socket.WebSocketHandler}. Files that need both must use a
 * fully qualified name.
 */
@Component
public class WebSocketHandler extends TextWebSocketHandler {
    /**
     * Logger named after this class's fully qualified name.
     */
    private static final Logger logger = Logger.getLogger(WebSocketHandler.class.getName());

    /**
     * Online users: client ID -> that client's open sessions.
     *
     * <p>The field is declared as ConcurrentHashMap rather than Map because its compute methods are
     * atomic per key. An entry is removed once its last session is gone, so the map does not grow
     * with every client ID ever seen.
     */
    private static final ConcurrentHashMap<String, List<WebSocketSession>> onlineUsers =
            new ConcurrentHashMap<>();

    /**
     * Registers a newly opened session under its client ID.
     *
     * <p>A session without a client ID could never be addressed, so it is closed with
     * {@link CloseStatus#POLICY_VIOLATION}.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String clientID = getWebSocketClientID(session);
        if (clientID == null) {
            recordLog("Websocket connection rejected, missing wsClientID, wsSessionID:%s。",
                    session.getId());
            session.close(CloseStatus.POLICY_VIOLATION);
            return;
        }
        addSession(clientID, session);
        recordLog("Websocket connection successful, wsClientID:%s,wsSessionID:%s。",
                clientID, session.getId());
    }

    /**
     * Unregisters a closed session.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status)
            throws Exception {
        String clientID = getWebSocketClientID(session);
        removeSession(clientID, session);
        recordLog("Websocket disconnected successfully, wsClientID:%s,wsSessionID:%s。",
                clientID, session.getId());
    }

    /**
     * Echoes a received text message back to the session that sent it.
     */
    @Override
    public void handleTextMessage(WebSocketSession session, TextMessage message) {
        String clientID = getWebSocketClientID(session);
        recordLog("Websocket received the message successfully, wsClientID:%s,wsSessionID:%s。",
                clientID, session.getId());
        sendMessage(clientID, session.getId(), message);
    }

    /**
     * Unregisters a session whose transport failed and logs the failure together with its cause.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) {
        String clientID = getWebSocketClientID(session);
        removeSession(clientID, session);
        logger.log(Level.WARNING,
                String.format("Websocket handles connection exceptions, wsClientID:%s,wsSessionID:%s。",
                        clientID, session.getId()),
                exception);
    }

    /**
     * Reads the client ID that the handshake interceptor stored in the session attributes.
     *
     * @param session WebSocketSession
     * @return the client ID, or {@code null} if the attribute is absent
     */
    private String getWebSocketClientID(WebSocketSession session) {
        return (String) session.getAttributes().get(WebSocketConstant.CLIENT_ID);
    }

    /**
     * Atomically adds {@code session} to the list for {@code clientID}, creating the list if needed.
     *
     * @param clientID non-null client ID
     * @param session session to register
     */
    private static void addSession(String clientID, WebSocketSession session) {
        onlineUsers.compute(clientID, (id, sessions) -> {
            List<WebSocketSession> target = sessions;
            if (target == null) {
                target = new CopyOnWriteArrayList<>();
            }
            target.add(session);
            return target;
        });
    }

    /**
     * Atomically removes {@code session} from the list for {@code clientID}. The map entry is
     * dropped once the list is empty.
     *
     * @param clientID client ID; {@code null} is ignored because the map rejects null keys
     * @param session session to unregister
     */
    private static void removeSession(String clientID, WebSocketSession session) {
        if (clientID == null) {
            return;
        }
        onlineUsers.computeIfPresent(clientID, (id, sessions) -> {
            sessions.remove(session);
            // Returning null removes the mapping.
            return sessions.isEmpty() ? null : sessions;
        });
    }

    /**
     * Sends a message to the sessions of a client.
     *
     * @param clientID client ID; {@code null} or an unknown ID sends nothing
     * @param sessionID {@code null} to send to every session of the client, otherwise only the
     *        session with this ID receives the message
     * @param textMessage message to send
     */
    private static void sendMessage(String clientID, String sessionID, TextMessage textMessage) {
        if (clientID == null) {
            // ConcurrentHashMap rejects null keys, and a null client ID owns no sessions anyway.
            return;
        }
        List<WebSocketSession> webSocketSessionList = onlineUsers.get(clientID);
        if (webSocketSessionList == null) {
            return;
        }
        for (WebSocketSession webSocketSession : webSocketSessionList) {
            String webSocketSessionID = webSocketSession.getId();
            // When a specific session is targeted, skip the others silently.
            if (sessionID != null && !sessionID.equals(webSocketSessionID)) {
                continue;
            }
            if (!webSocketSession.isOpen()) {
                recordLog(
                        "Send webSocket message failed,client state is disconnected,wsClientID:%s,wsSessionID:%s。",
                        clientID, webSocketSessionID);
                continue;
            }
            try {
                // Standard WebSocket sessions do not allow concurrent sends, so serialize per session.
                synchronized (webSocketSession) {
                    webSocketSession.sendMessage(textMessage);
                }
                recordLog("Send webSocket message success, wsClientID:%s,wsSessionID:%s。",
                        clientID, webSocketSessionID);
            } catch (IOException | IllegalStateException e) {
                // IllegalStateException covers a session that closed between isOpen() and the send.
                // Keep going so one broken session does not block delivery to the others.
                logger.log(Level.WARNING, String.format(
                        "Send webSocket message exception,wsClientID:%s,wsSessionID:%s,exceptionMessage:%s。",
                        clientID, webSocketSessionID, e.getMessage()), e);
            }
        }
    }

    /**
     * Pushes a text message to every open session of a client.
     *
     * @param clientID client ID; nothing is sent when it is {@code null} or has no registered sessions
     * @param message message text
     */
    public static void sendMessage(String clientID, String message) {
        sendMessage(clientID, null, new TextMessage(message));
    }

    /**
     * Logs a formatted message at INFO level.
     *
     * @param message format string for {@link String#format}
     * @param args values substituted into {@code message}
     */
    private static void recordLog(String message, Object... args) {
        logger.info(String.format(message, args));
    }
}
五、WebSocket常量类
package framework.websocket;
/**
 * Constants shared by the WebSocket handshake interceptor and handler.
 */
public final class WebSocketConstant {
    /**
     * Name of the handshake request parameter that carries the client ID.
     *
     * <p>The same string is the key under which the ID is stored in the WebSocket session
     * attributes.
     */
    public static final String CLIENT_ID = "clientID";

    /** This class only holds constants and must not be instantiated. */
    private WebSocketConstant() {
        throw new AssertionError("No WebSocketConstant instances");
    }
}