由于websocket是长连接,session保持在一个server中,所以在不同server在使用websocket推送消息时就需要获取对应的session进行推送,在分布式系统中就无法获取到所有session,这里就需要使用一个中间件将消息推送到各个系统中,在这里使用的redis,使用redis的sub/pub功能。
在项目中,我使用的是spring-boot-starter-data-redis进行redis的连接。websocket发送到不同的用户,在websocket连接成功后,将usercode,websocketsession放入Map集合中。如果需要发送消息给相应的websocket连接,则发送到redis的对应频道上;同时系统订阅该频道,对监听得到的消息进行处理,得到相应的websocket的usercode,获取session,再使用session进行websocket的消息推送。
websocket代码
public class OrderHandler extends TextWebSocketHandler{
public static final HashMap<String, WebSocketSession> USER_SESSION_MAP;
static {
userSessionMap = new ConcurrentHashMap<>();
}
/**
* 连接成功时候,会触发UI上onopen方法
*/
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
System.out.println("connect to the websocket success......");
Map<String, Object> attributes = session.getAttributes();
String userCode = (String) attributes.get("userCode");
if(userCode != null) {
userSessionMap.put(userCode, session);
}
}
/**
* 在UI在用js调用websocket.send()时候,会调用该方法
*/
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
sendMessageToUsers(message);
}
/**
* 给某个用户发送消息
*
* @param userName
* @param message
*/
public void sendMessageToUser(String userName, TextMessage message) {
}
/**
* 给所有在线用户发送消息
*
* @param message
*/
public void sendMessageToUsers(TextMessage message) {
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
if (session.isOpen()) {
session.close();
}
Map<String, Object> attributes = session.getAttributes();
String userCode = (String) attributes.get("userCode");
if(userCode != null) {
userSessionMap.remove(userCode);
}
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
Map<String, Object> attributes = session.getAttributes();
String userCode = (String) attributes.get("userCode");
if(userCode != null) {
userSessionMap.remove(userCode);
}
}
@Override
public boolean supportsPartialMessages() {
return false;
}
}
redis的配置
@Configuration
public class RedisConfig {
private Executor redisTaskExecutor;
@Autowired(
required = false //该处理监听的线程池不是必须的,如果不自定义默认将使用SimpleAsyncTaskExecutor线程池
)
@Qualifier("springSessionRedisTaskExecutor")
public void setRedisTaskExecutor(Executor redisTaskExecutor) {
this.redisTaskExecutor = redisTaskExecutor;
}
@Bean
public RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory,
MessageListenerAdapter XXXListenerAdapter,
MessageListenerAdapter YYYListenerAdapter) {
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
container.setConnectionFactory(connectionFactory);
if(this.redisTaskExecutor != null) {
container.setTaskExecutor(this.redisTaskExecutor);
}
container.addMessageListener(XXXListenerAdapter, new PatternTopic("XXX"));
container.addMessageListener(YYYListenerAdapter, new PatternTopic("YYY"));
return container;
}
@Bean
public MessageListenerAdapter orderCodeListenerAdapter(XXXReceiver receiver) {
//这个地方 是给messageListenerAdapter 传入一个消息接受的处理器,利用反射的方法调用“receiveMessage”
//也有好几个重载方法,这边默认调用处理器的方法 叫handleMessage 可以自己到源码里面看
return new MessageListenerAdapter(receiver, "receiveMessage");
}
@Bean
public MessageListenerAdapter orderCloseListenerAdapter(YYYReceiver receiver) {
return new MessageListenerAdapter(receiver, "receiveMessage");
}
@Bean
public StringRedisTemplate stringRedisTemplate(RedisConnectionFactory connectionFactory) {
return new StringRedisTemplate(connectionFactory);
}
/**
* 这里的之采用Jackson2JsonRedisSerializer进行序列化,将发送到
* redis的对象序列化
*/
@Bean
public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(connectionFactory);
redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(RedisMessageVO.class));
return redisTemplate;
}
@Bean
public ThreadPoolTaskExecutor springSessionRedisTaskExecutor(){
ThreadPoolTaskExecutor springSessionRedisTaskExecutor = new ThreadPoolTaskExecutor();
springSessionRedisTaskExecutor.setCorePoolSize(4);
springSessionRedisTaskExecutor.setMaxPoolSize(8);
springSessionRedisTaskExecutor.setKeepAliveSeconds(10);
springSessionRedisTaskExecutor.setQueueCapacity(1000);
springSessionRedisTaskExecutor.setThreadNamePrefix("Spring session redis executor thread: ");
return springSessionRedisTaskExecutor;
}
}
这里监听redis的三个频道,使用三个方法处理三个频道的消息。
@Component
public class XXXReceiver {
public void receiveMessage(String message) throws IOException{
System.out.println("redis message to phone:"+message);
Jackson2JsonRedisSerializer<RedisMessageVO> redisSerializer = new Jackson2JsonRedisSerializer<>(RedisMessageVO.class);
RedisMessageVO messageVO = redisSerializer.deserialize(message.getBytes());
String code = messageVO.getCode();
WebSocketSession session = OrderHandler.USER_SESSION_MAP.get(code);
if(session != null) {
TextMessage tm = new TextMessage(messageVO.getMessage());
session.sendMessage(tm);
}
}
}
该方法是处理订阅的redis消息,反序列化得到消息vo,根据里面的标识去websocket session的Map集合中获取session,如果能够获取到,则使用该session推送消息,如果没有获取到则说明websocket连接不再这个服务里。