Spring Boot 中 WebSocket 的使用
1. 添加依赖
<!-- Spring Boot WebSocket starter. No version element is given here, so this assumes the
     version is managed by the Spring Boot parent POM or an imported BOM; confirm in pom.xml. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2. 添加配置类
/**
 * Enables Spring's raw WebSocket support and maps the application's handler under
 * {@code Constants.WebSocket.PATH_PREFIX}.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /**
     * Exposes the application's WebSocket handler as a container-managed bean, so that its
     * annotated fields (e.g. {@code @Resource}) are injected.
     *
     * @return a new handler instance
     */
    @Bean
    public WebSocketHandler sysWebSocketHandler() {
        WebSocketHandler handler = new WebSocketHandler();
        return handler;
    }

    /**
     * Registers the handler for the pattern {@code PATH_PREFIX + "/*"}, i.e. one path segment
     * directly below the prefix.
     *
     * <p>NOTE(review): the allowed origin {@code "*"} accepts handshakes from any web origin,
     * which leaves browsers open to cross-site WebSocket hijacking; restrict it if the endpoint
     * relies on cookies or other ambient credentials.
     *
     * @param registry Spring's handler registry
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        String mapping = Constants.WebSocket.PATH_PREFIX + "/*";
        registry.addHandler(sysWebSocketHandler(), mapping).setAllowedOrigins("*");
    }
}
3. 添加消息处理器
@Component
public class WebSocketHandler implements WebSocketHandler {
private static final Logger log = LoggerFactory.getLogger(WebSocketHandler.class);
@Resource
private IWebSocketService websocketService;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
websocketService.handleOpen(session);
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
if (message instanceof TextMessage){
TextMessage textMessage = (TextMessage) message;
websocketService.handleMessage(session,textMessage.getPayload());
}
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
websocketService.handleError(session,exception);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
websocketService.handleClose(session);
}
@Override
public boolean supportsPartialMessages() {
return false;
}
}
4. WebSocket Service 接口
package com.fuj.wms.common.core.service;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.util.Map;
/**
 * Service contract behind the WebSocket handler. It covers connection-lifecycle callbacks,
 * targeted sends, broadcasts, explicit closing and session introspection.
 */
public interface IWebSocketService {

    /* ---------- lifecycle callbacks (invoked from the WebSocket handler) ---------- */

    /** Invoked after a connection has been established. */
    void handleOpen(WebSocketSession session);

    /** Invoked for each incoming text message; {@code message} is its payload. */
    void handleMessage(WebSocketSession session, String message);

    /** Invoked when a transport error occurs on {@code session}. */
    void handleError(WebSocketSession session, Throwable error);

    /** Invoked after the connection has been closed. */
    void handleClose(WebSocketSession session);

    /* ---------- point-to-point sends ---------- */

    /** Sends {@code message} as a text message to {@code session}. */
    void sendMessage(WebSocketSession session, String message);

    /** Sends {@code message} to {@code session}. */
    void sendMessage(WebSocketSession session, TextMessage message);

    /** Sends {@code message} as a text message to the session registered under {@code sessionId}. */
    void sendMessage(String sessionId, String message);

    /** Sends {@code message} to the session registered under {@code sessionId}. */
    void sendMessage(String sessionId, TextMessage message);

    /* ---------- broadcasts ---------- */

    /** Sends {@code message} to every open registered session. */
    void broadCast(String message);

    /** Sends {@code message} to every open registered session. */
    void broadCast(TextMessage message);

    /** Sends {@code message} to every open registered session whose id is not in {@code sessionIds}. */
    void broadCastExclude(String message, String... sessionIds);

    /** Sends {@code message} to every open registered session whose id is not in {@code sessionIds}. */
    void broadCastExclude(TextMessage message, String... sessionIds);

    /* ---------- closing ---------- */

    /** Closes {@code session} if it is still open. */
    void closeWebSocket(WebSocketSession session);

    /** Closes the session registered under {@code sessionId}. */
    void closeWebSocket(String sessionId);

    /* ---------- introspection ---------- */

    /** Returns the registered sessions keyed by session id. */
    Map<String, WebSocketSession> getSessions();

    /** Returns the number of currently registered connections. */
    int getConnectionCount();

    /** Returns the request path of {@code session} (the base implementation strips the configured prefix). */
    String getRequest(WebSocketSession session);

    /** Returns the remote peer's IP address as text. */
    String getHostAddress(WebSocketSession session);

    /** Returns the remote peer's host name. */
    String getHostName(WebSocketSession session);
}
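5. WebSocket Service 接口实现类（注意：WebSocketService 是抽象类，需要再提供一个继承它的具体子类并注册为 Spring Bean，消息处理器中的 @Resource 才能注入 IWebSocketService）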
package com.fuj.wms.common.core.service;
import com.fuj.wms.common.constant.CacheConstants;
import com.fuj.wms.common.constant.Constants;
import com.fuj.wms.common.core.domain.model.FujWebSocket;
import com.fuj.wms.common.core.redis.RedisCache;
import com.fuj.wms.common.utils.StringUtils;
import com.fuj.wms.common.utils.spring.SpringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
/**
 * Abstract base implementation of {@link IWebSocketService}.
 *
 * <p>Open sessions are kept in a JVM-wide {@link ConcurrentHashMap}. Each session is wrapped in a
 * {@link ConcurrentWebSocketSessionDecorator} so that concurrent sends to the same session are
 * serialized. Per-session metadata ({@link FujWebSocket}) and an online counter are mirrored into
 * Redis through {@code RedisCache}. The class is abstract, so a concrete subclass is needed before
 * it can be instantiated.
 */
public abstract class WebSocketService implements IWebSocketService {
    private static final Logger log = LoggerFactory.getLogger(WebSocketService.class);

    /** Redis key of the map session id -> {@link FujWebSocket}. */
    private static final String REDIS_MAP_KEY = CacheConstants.WMS_KEY + CacheConstants.Monitor.WEBSOCKET_MAP_KEY;
    /** Redis key of the online-connection counter. */
    private static final String REDIS_COUNT_KEY = CacheConstants.WMS_KEY + CacheConstants.Monitor.WEBSOCKET_COUNT_KEY;

    /** Number of sessions registered in {@link #sessionsMap}; static, so shared by every instance. */
    protected static AtomicInteger connectionCount = new AtomicInteger(0);
    /** Session id -> send-serializing decorated session, for every session opened on this JVM. */
    protected static Map<String, WebSocketSession> sessionsMap = new ConcurrentHashMap<>();
    /**
     * Send-time limit in milliseconds for {@link ConcurrentWebSocketSessionDecorator}.
     * NOTE(review): 10 * 1000 * 1000 ms is 10,000 s (about 2.8 h). If 10 s was intended, this
     * should be 10 * 1000; confirm before changing.
     */
    protected final int sendTimeLimit = 10 * 1000 * 1000;
    /**
     * Per-session send-buffer limit in bytes (1 GiB) for {@link ConcurrentWebSocketSessionDecorator}.
     * NOTE(review): every slow client may buffer up to this much heap; a far smaller limit is usual.
     */
    protected final int bufferSizeLimit = 1024 * 1024 * 1024;

    /**
     * Registers a newly opened session. It is stored (decorated) in {@link #sessionsMap}, the local
     * counter is incremented, and the session metadata and online counter are recorded in Redis.
     */
    @Override
    public void handleOpen(WebSocketSession session) {
        log.info("WebSocket连接建立中 ==> session_id = {},IP地址 = {},PC名称 = {},uri = {}", session.getId(), getHostAddress(session), getHostName(session), getRequest(session));
        sessionsMap.put(session.getId(), new ConcurrentWebSocketSessionDecorator(session, sendTimeLimit, bufferSizeLimit));
        connectionCount.incrementAndGet();
        log.info("WebSocket连接建立成功,当前在线数为: {} ==> 开始监听新连接:session_id = {}。", connectionCount.get(), session.getId());

        RedisCache cache = redisCache();
        Map<String, FujWebSocket> webSocketMap = cache.getCacheMap(REDIS_MAP_KEY);
        if (StringUtils.isEmpty(webSocketMap)) {
            // Seeds an empty map; presumably redundant, since setCacheMapValue below writes the entry anyway.
            webSocketMap = new ConcurrentHashMap<>();
            cache.setCacheMap(REDIS_MAP_KEY, webSocketMap);
        }
        cache.setCacheMapValue(REDIS_MAP_KEY, session.getId(), createFujWebSocket(session));

        // NOTE(review): this check-then-set is not atomic. Two sessions opening concurrently can both
        // see null, and the second set may overwrite the first increment. If RedisCache.incr maps
        // to Redis INCR (a missing key counts as 0), the seeding step can likely be dropped; confirm.
        Object count = cache.getCacheObject(REDIS_COUNT_KEY);
        if (StringUtils.isNull(count)) {
            cache.setCacheObject(REDIS_COUNT_KEY, Constants.Number.INTEGER_DEFAULT_VALUE);
        }
        cache.incr(REDIS_COUNT_KEY);
    }

    /**
     * Unregisters a closed session and removes its Redis entry. Counters are decremented only if
     * the session was actually registered. That case can fail to hold, for example, when
     * {@link #handleOpen} threw before storing the session; the guard prevents counters from
     * drifting below the real number of connections.
     */
    @Override
    public void handleClose(WebSocketSession session) {
        log.info("WebSocket连接关闭中 ==> session_id = {},IP地址 = {},PC名称 = {},接口 = {}", session.getId(), getHostAddress(session), getHostName(session), getRequest(session));
        boolean wasRegistered = sessionsMap.remove(session.getId()) != null;
        if (wasRegistered) {
            connectionCount.decrementAndGet();
        } else {
            log.warn("WebSocket关闭的会话未在本地注册,不调整在线数 ==> session_id = {}", session.getId());
        }
        log.info("WebSocket连接关闭成功,当前在线数为: {} ==> 已关闭连接:session_id = {}。", connectionCount.get(), session.getId());

        RedisCache cache = redisCache();
        cache.deleteCacheMapValue(REDIS_MAP_KEY, session.getId());
        if (wasRegistered) {
            cache.decr(REDIS_COUNT_KEY);
        }
    }

    /** Default handling of an incoming text message: it is only logged. Subclasses may override. */
    @Override
    public void handleMessage(WebSocketSession session, String message) {
        log.info("WebSocket服务端收到客户端消息 ==> session_id = {},IP地址 = {},PC名称 = {},接口 = {},message = {}", session.getId(), getHostAddress(session), getHostName(session), getRequest(session), message);
    }

    /** Sends {@code message} as a text message to {@code session}; I/O failures are logged, not thrown. */
    @Override
    public void sendMessage(WebSocketSession session, String message) {
        doSend(session, new TextMessage(message));
    }

    /** Sends {@code message} to the registered session {@code sessionId}; unknown or null ids are logged and ignored. */
    @Override
    public void sendMessage(String sessionId, TextMessage message) {
        WebSocketSession webSocketSession = lookupSession(sessionId);
        if (webSocketSession != null) {
            this.sendMessage(webSocketSession, message);
        }
    }

    /** Sends {@code message} to the registered session {@code sessionId}; unknown or null ids are logged and ignored. */
    @Override
    public void sendMessage(String sessionId, String message) {
        WebSocketSession webSocketSession = lookupSession(sessionId);
        if (webSocketSession != null) {
            this.sendMessage(webSocketSession, message);
        }
    }

    /** Sends {@code message} to {@code session}; I/O failures are logged, not thrown. */
    @Override
    public void sendMessage(WebSocketSession session, TextMessage message) {
        doSend(session, message);
    }

    /** Sends {@code message} to every open registered session; a failure on one session does not stop the others. */
    @Override
    public void broadCast(String message) {
        forEachOpenSession(Collections.<String>emptySet(), toSession -> this.sendMessage(toSession, message));
    }

    /** Sends {@code message} to every open registered session; a failure on one session does not stop the others. */
    @Override
    public void broadCast(TextMessage message) {
        forEachOpenSession(Collections.<String>emptySet(), toSession -> this.sendMessage(toSession, message));
    }

    /** Like {@link #broadCast(String)}, but skips sessions whose id is in {@code sessionIds} (null means none). */
    @Override
    public void broadCastExclude(String message, String... sessionIds) {
        forEachOpenSession(toIdSet(sessionIds), toSession -> this.sendMessage(toSession, message));
    }

    /** Like {@link #broadCast(TextMessage)}, but skips sessions whose id is in {@code sessionIds} (null means none). */
    @Override
    public void broadCastExclude(TextMessage message, String... sessionIds) {
        forEachOpenSession(toIdSet(sessionIds), toSession -> this.sendMessage(toSession, message));
    }

    /** Logs a transport error, including the throwable's stack trace. */
    @Override
    public void handleError(WebSocketSession session, Throwable error) {
        // The trailing Throwable argument (beyond the 5 placeholders) makes SLF4J print its stack trace.
        log.error("WebSocket发生错误 ==> session_id = {},IP地址 = {},PC名称 = {},接口 = {},错误信息为 = {}", session.getId(), getHostAddress(session), getHostName(session), getRequest(session), error.getMessage(), error);
    }

    /** Closes {@code session} if it is non-null and still open; I/O failures are logged. */
    @Override
    public void closeWebSocket(WebSocketSession session) {
        if (session == null || !session.isOpen()) {
            return;
        }
        try {
            session.close();
        } catch (IOException e) {
            log.error("WebSocket关闭连接失败 ==> session_id = {}", session.getId(), e);
        }
    }

    /** Closes the registered session {@code sessionId}; unknown or null ids are logged and ignored. */
    @Override
    public void closeWebSocket(String sessionId) {
        WebSocketSession webSocketSession = lookupSession(sessionId);
        if (webSocketSession != null) {
            this.closeWebSocket(webSocketSession);
        }
    }

    /** Returns the live, shared registry of sessions (not a copy). */
    @Override
    public Map<String, WebSocketSession> getSessions() {
        return sessionsMap;
    }

    /** Returns the number of sessions registered on this JVM. */
    @Override
    public int getConnectionCount() {
        return connectionCount.get();
    }

    /**
     * Returns the session's URI path with {@code PATH_PREFIX + "/"} removed and trimmed, or
     * {@code null} if the session has no URI or path.
     */
    @Override
    public String getRequest(WebSocketSession session) {
        URI uri = session.getUri();
        String path = uri == null ? null : uri.getPath();
        if (path == null) {
            return null;
        }
        return StringUtils.replace(path, Constants.WebSocket.PATH_PREFIX + "/", "").trim();
    }

    /**
     * Returns the remote IP address as text. For an unresolved address it falls back to the host
     * string; it returns {@code null} when the session exposes no remote address.
     */
    @Override
    public String getHostAddress(WebSocketSession session) {
        InetSocketAddress remote = session.getRemoteAddress();
        if (remote == null) {
            return null;
        }
        return remote.getAddress() == null ? remote.getHostString() : remote.getAddress().getHostAddress();
    }

    /**
     * Returns the remote host name, or {@code null} when the session exposes no remote address.
     * NOTE(review): InetSocketAddress#getHostName may trigger a reverse name lookup, and this is
     * called on every log line; consider getHostString() if that latency matters.
     */
    @Override
    public String getHostName(WebSocketSession session) {
        InetSocketAddress remote = session.getRemoteAddress();
        return remote == null ? null : remote.getHostName();
    }

    /** Builds the metadata record stored in Redis for {@code session}. */
    private FujWebSocket createFujWebSocket(WebSocketSession session) {
        FujWebSocket fujWebSocket = new FujWebSocket(session.getId(), getRequest(session), getHostAddress(session), getHostName(session), Objects.toString(session.getUri(), null));
        fujWebSocket.setCreateTime(new Date());
        fujWebSocket.setCreateBy(Constants.User.WMS);
        return fujWebSocket;
    }

    /** Logs and performs a send, preferring the registered send-serializing decorator for the session. */
    private void doSend(WebSocketSession session, TextMessage message) {
        // The raw session handed to handler callbacks must not be written concurrently with
        // broadcasts, which use the decorator; route through the decorator whenever one is registered.
        WebSocketSession target = sessionsMap.getOrDefault(session.getId(), session);
        try {
            log.info("WebSocket发送消息 ==> session_id = {},IP地址 = {},PC名称 = {},接口 = {},message = {}", session.getId(), getHostAddress(session), getHostName(session), getRequest(session), message.getPayload());
            target.sendMessage(message);
        } catch (IOException e) {
            log.error("WebSocket发送消息失败 ==> session_id = {}", session.getId(), e);
        }
    }

    /** Returns the registered session for {@code sessionId}, or logs and returns {@code null} if the id is null or unknown. */
    private WebSocketSession lookupSession(String sessionId) {
        WebSocketSession found = sessionId == null ? null : sessionsMap.get(sessionId);
        if (found == null) {
            log.warn("WebSocket会话不存在或已关闭 ==> session_id = {}", sessionId);
        }
        return found;
    }

    /** Applies {@code sender} to every open registered session not in {@code excludedIds}, isolating per-session failures. */
    private void forEachOpenSession(Set<String> excludedIds, Consumer<WebSocketSession> sender) {
        sessionsMap.forEach((sessionId, toSession) -> {
            if (excludedIds.contains(sessionId) || !toSession.isOpen()) {
                return;
            }
            try {
                sender.accept(toSession);
            } catch (RuntimeException e) {
                // e.g. SessionLimitExceededException from the decorator: skip this session, keep broadcasting.
                log.error("WebSocket广播消息失败 ==> session_id = {}", sessionId, e);
            }
        });
    }

    /** Converts the exclusion varargs into a set, built once per broadcast; null or empty means "exclude nothing". */
    private static Set<String> toIdSet(String... sessionIds) {
        if (sessionIds == null || sessionIds.length == 0) {
            return Collections.emptySet();
        }
        return new HashSet<>(Arrays.asList(sessionIds));
    }

    /** Resolves the RedisCache bean from the Spring context (this class has no injected reference to it). */
    private static RedisCache redisCache() {
        return SpringUtils.getBean(RedisCache.class);
    }
}