pom.xml
<!-- Spring Boot WebSocket starter: brings in the spring-websocket support used by
     MessageHandler / MessageHandshakeInterceptor / WebSocketConfig.
     NOTE(review): no <version> is given, presumably managed by the Spring Boot parent/BOM - confirm. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
MessageHandler
package com.shareworx.websocket;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.shareworx.model.WebsocketMessageModel;
import com.shareworx.platform.util.DateUtil;
import com.shareworx.pojo.Message;
import com.shareworx.service.WebsocketMessageBusinessService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * WebSocket handler for one-to-one chat messages between users.
 *
 * <p>Every open connection is registered in {@link #SESSIONS} under the user id found in the
 * {@code "uid"} session attribute. Incoming text frames are parsed as JSON
 * ({@code {"toId": ..., "msg": ...}}), persisted through
 * {@link WebsocketMessageBusinessService}, and forwarded to the recipient's session when that
 * recipient is connected to this node.
 *
 * <p>Thread-safety: this is a singleton bean whose callbacks run on container threads, so the
 * session registry is a concurrent map. Writes to a single session are serialized because the
 * underlying WebSocket session does not allow concurrent sends.
 */
@Component
public class MessageHandler extends TextWebSocketHandler {

    // Shared JSON mapper. Deliberately not @Autowired: Spring does not inject static fields.
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Open sessions keyed by user id; accessed concurrently by connection and send callbacks. */
    private static final Map<Long, WebSocketSession> SESSIONS = new ConcurrentHashMap<>();

    @Autowired
    private WebsocketMessageBusinessService websocketMessageBusinessService;

    /**
     * Registers the newly opened session under the {@code "uid"} session attribute so that later
     * messages addressed to this user can be delivered to it.
     *
     * @param session the session that was just opened
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        Long uid = (Long) session.getAttributes().get("uid");
        String uname = (String) session.getAttributes().get("uname");
        if (uid == null) {
            // A session without a user id can never be addressed (and a null key is not allowed
            // in the registry), so do not keep it open.
            session.close();
            return;
        }
        SESSIONS.put(uid, session);
        System.out.println(uid + "的连接已经建立了..............用户名:" + uname);
    }

    /**
     * Handles a chat message sent by a client. The payload must be JSON with {@code toId} and
     * {@code msg} fields. The message is stored, then pushed to the recipient if they are
     * connected to this node.
     *
     * @param session the sender's session (its {@code "uid"} attribute is the sender id)
     * @param textMessage the raw text frame
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) throws Exception {
        Long uid = (Long) session.getAttributes().get("uid");
        JsonNode jsonNode = MAPPER.readTree(textMessage.getPayload());
        JsonNode toIdNode = jsonNode == null ? null : jsonNode.get("toId");
        JsonNode msgNode = jsonNode == null ? null : jsonNode.get("msg");
        if (toIdNode == null || msgNode == null) {
            // Throwing here (e.g. an NPE) would make Spring close the sender's connection;
            // drop the malformed message instead.
            System.err.println("Ignoring message without toId/msg from user " + uid);
            return;
        }
        Long toId = toIdNode.asLong();
        String msg = msgNode.asText();
        String date = DateUtil.formatDateTime(new Date());
        Message message = Message.builder()
                .fromId(uid)
                .toId(toId)
                .msg(msg)
                .sendDate(date)
                .build();
        // Persist first so the outgoing message carries its database id.
        WebsocketMessageModel saved = saveMessage(uid, toId, msg, date);
        message.setId(saved.getId());
        TextMessage outgoing = new TextMessage(MAPPER.writeValueAsString(message));

        WebSocketSession toSession = SESSIONS.get(toId);
        if (toSession != null && toSession.isOpen()) {
            try {
                send(toSession, outgoing);
            } catch (IOException e) {
                // A broken recipient connection must not propagate and close the sender's session;
                // the message is already stored.
                e.printStackTrace();
            }
        } else {
            // TODO: the recipient may be offline or connected to another node; publish to MQ.
        }
    }

    /**
     * Unregisters the closed session. Only the exact session is removed, so a newer connection
     * opened by the same user in the meantime stays registered.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        Long uid = (Long) session.getAttributes().get("uid");
        if (uid != null) {
            SESSIONS.remove(uid, session);
        }
        System.out.println(uid + "的连接已经断开");
    }

    /**
     * Delivers an already-serialized message (JSON containing {@code toId}) to the recipient if
     * they are connected to this node; otherwise does nothing. Errors are printed, not thrown.
     *
     * @param msg JSON message text, forwarded to the recipient unchanged
     */
    public void onMessage(String msg) {
        try {
            JsonNode jsonNode = MAPPER.readTree(msg);
            Long toId = jsonNode.get("toId").asLong();
            WebSocketSession toSession = SESSIONS.get(toId);
            if (toSession != null && toSession.isOpen()) {
                send(toSession, new TextMessage(msg));
            }
            // Recipient not connected to this node: nothing to do.
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Sends a message to one user.
     *
     * <p>Nothing happens unless {@code msg.getToId()} equals {@code userId} and that user has an
     * open session on this node. In that case the message is persisted, its id is set on
     * {@code msg}, and it is sent as JSON.
     *
     * @param userId the recipient's user id
     * @param msg the message to deliver; its {@code id} is updated after saving
     */
    public void sendMessageToUser(Long userId, Message msg) {
        Long toId = msg.getToId();
        if (toId == null || !toId.equals(userId)) {
            return;
        }
        // Look up the recipient's own session (previously the first session in the map was used).
        WebSocketSession session = SESSIONS.get(userId);
        if (session == null || !session.isOpen()) {
            return;
        }
        try {
            WebsocketMessageModel saved = saveMessage(msg.getFromId(), toId, msg.getMsg(), msg.getSendDate());
            msg.setId(saved.getId());
            send(session, new TextMessage(MAPPER.writeValueAsString(msg)));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Broadcasts a message to every open session on this node. A failure on one session is
     * printed and does not stop delivery to the others.
     *
     * @param message the frame to broadcast
     */
    public void sendMessageToUsers(TextMessage message) {
        for (WebSocketSession session : SESSIONS.values()) {
            try {
                if (session.isOpen()) {
                    send(session, message);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Saves a chat message with deleteFlag 0 and status 1 and returns the saved model.
     * NOTE(review): the meaning of status 1 is defined by WebsocketMessageModel; confirm there.
     */
    private WebsocketMessageModel saveMessage(Long fromId, Long toId, String msg, String sendDate) {
        WebsocketMessageModel model = new WebsocketMessageModel();
        model.setDeleteFlag(0);
        model.setFromId(fromId);
        model.setToId(toId);
        model.setMsg(msg);
        model.setStatus(1);
        model.setSendDate(sendDate);
        websocketMessageBusinessService.save(model);
        return model;
    }

    /** Sends on a session, serializing concurrent writers (concurrent sends are not allowed). */
    private static void send(WebSocketSession session, TextMessage message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }
}
说明
通过继承 TextWebSocketHandler 类并覆盖相应方法,可以对 WebSocket 的事件进行处理。这些方法可以与原生(JSR-356)注解一一对应起来理解:
- **afterConnectionEstablished** 方法在 socket 连接建立成功后被触发,相当于原生注解里的 @OnOpen
- **afterConnectionClosed** 方法在 socket 连接关闭后被触发,相当于原生注解里的 @OnClose
- **handleTextMessage** 方法在客户端发送文本消息时被触发,相当于原生注解里的 @OnMessage
MessageHandshakeInterceptor
package com.shareworx.websocket;
import com.shareworx.service.IUserService;
import com.shareworx.util.JWTUtils;
import io.jsonwebtoken.Claims;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import java.util.Map;
@Component
public class MessageHandshakeInterceptor implements HandshakeInterceptor {
@Autowired
private IUserService service;
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
try {
String token = ((ServletServerHttpRequest) request).getServletRequest().getParameter("token");
if (StringUtils.isBlank(token)) {
return false;
}
Claims claims = JWTUtils.getClaims(token);
if (claims == null) {
return false;
}
String username = claims.get("username").toString();
if (StringUtils.isBlank(username)) {
return false;
}
Long id = service.getUserIdByUserName(username);
if (id == null) {
return false;
}
attributes.put("uid", id);
attributes.put("uname", username);
return true;
} catch (Exception e) {
e.printStackTrace();
return false;
}
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
}
}
说明
通过实现 HandshakeInterceptor 接口来定义握手拦截器。注意这里的事件与上面 Handler 的事件不同:这里是建立握手时的事件,分为握手前与握手后;而 Handler 的事件是在握手成功、socket 连接建立之后才触发的。beforeHandshake 返回 false 时握手会被拒绝,连接根本不会建立,所以把认证放在这一步相对来说最节省服务器资源。它主要有两个方法 **beforeHandshake** 与 **afterHandshake**,顾名思义,一个在握手前触发,一个在握手后触发。
WebSocketConfig
package com.shareworx.websocket;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
 * WebSocket configuration: exposes {@link MessageHandler} at {@code /ws} and runs
 * {@link MessageHandshakeInterceptor} on every handshake to that endpoint.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** Path under which the chat endpoint is published. */
    private static final String ENDPOINT_PATH = "/ws";

    /** Accepted Origin values; "*" accepts any origin (handy locally, restrict in production). */
    private static final String[] ALLOWED_ORIGINS = {"*"};

    @Autowired
    private MessageHandler messageHandler;

    @Autowired
    private MessageHandshakeInterceptor messageHandshakeInterceptor;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry
                .addHandler(messageHandler, ENDPOINT_PATH)
                .setAllowedOrigins(ALLOWED_ORIGINS)
                .addInterceptors(messageHandshakeInterceptor);
    }
}
说明
通过实现 WebSocketConfigurer 接口并覆盖相应的方法来进行 WebSocket 的配置,我们主要覆盖 registerWebSocketHandlers 这个方法,通过向 WebSocketHandlerRegistry 设置不同参数来完成配置。其中 **addHandler** 方法添加我们上面写的 ws handler 处理类,第二个参数是对外暴露的 ws 路径;**addInterceptors** 添加我们写的握手拦截器;**setAllowedOrigins("*")** 表示允许任意来源(Origin)的连接,相当于关闭跨域校验,方便本地调试,线上环境建议改为只允许指定的域名。