<!-- WebSocket dependency: adds WebSocket support so the application can do socket communication -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
package com.szch3.h1s.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * Enables WebSocket support.
 *
 * <p>Exposes a {@link ServerEndpointExporter} bean. Per Spring's documentation, the exporter
 * detects beans annotated with {@code @ServerEndpoint} and registers them with the standard Java
 * WebSocket runtime.
 *
 * <p>NOTE(review): Spring documents this exporter as needed when running in an embedded
 * container. When deploying to a container that scans for endpoints itself, it may be
 * unnecessary. Confirm for the target deployment.
 *
 * <p>NOTE(review): the class name is misspelled ("Scoket"). It is kept as-is so existing
 * references do not break.
 *
 * @author DoubleC
 */
@Configuration
public class WebScoketConfig {

    /**
     * Registers the exporter that publishes {@code @ServerEndpoint} beans to the WebSocket container.
     *
     * @return a new {@link ServerEndpointExporter}
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}
package com.szch3.h1s.config;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import net.sf.json.JSONObject;
@ServerEndpoint("/websocket/{sid}")
@Component
public class WebSocketServer {
private static final Logger log=LoggerFactory.getLogger(WebSocketServer.class);
/**
* 当前在线连接数
*/
private static AtomicInteger onlineCount=new AtomicInteger(0);
/**
* 用来存放每个客户端对应的WebSocketServer对象
*/
//private static ConcurrenHashMap<String,WebSocketServer> webSocketMap=new ConcurrentHashMap<>();
///private static CopyOnWriteArraySet<WebSocketServer> webSocketSet=new CopyOnWriteArraySet<WebSocketServer>();
private static ConcurrentHashMap<String, WebSocketServer> webSocketMap =new ConcurrentHashMap<>();
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private Session session;
/**
* 接受sid
*/
private String sid="";
@OnOpen
public void onOpen(Session session,@PathParam("sid")String sid) {
this.session=session;
this.sid=sid;
if(webSocketMap.containsKey(sid)) {
webSocketMap.remove(sid);
webSocketMap.put(sid, this);
}else {
webSocketMap.put(sid, this);
addOnlineCount();
}
log.info("用户连接:"+sid+",当前在线人数为:"+getOnlineCount());
try {
sendMessage("连接成功!");
} catch (Exception e) {
// TODO: handle exception
log.error("用户:"+sid+",网络异常!!!!!!");
}
}
/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose() {
if(webSocketMap.containsKey(sid)) {
webSocketMap.remove(sid);
subOnlineCount();
}
log.info("用户退出:" + sid + ",当前在线人数为:" + getOnlineCount());
}
/**
* 收到客户端消息后调用的方法
* @param message
* @param session
*/
@OnMessage
public void onMessage(String message,Session session) {
log.info("用户消息:"+sid+",报文:"+message);
if(!StringUtils.isEmpty(message)) {
try {
JSONObject jsonObject = JSONObject.fromObject(message);
jsonObject.put("userId", this.sid);
System.out.println(jsonObject.toString());
String toUserId=jsonObject.getString("userId");
System.out.println(toUserId);
if(!StringUtils.isEmpty(toUserId)&& webSocketMap.containsKey(toUserId)) {
for (int i = 0; i <2; i++) {
webSocketMap.get(toUserId).sendMessage(jsonObject.toString());
}
}else {
log.error("请求的userId:"+toUserId+"不在该服务器上");
}
} catch (Exception e) {
// TODO: handle exception
e.printStackTrace();
}
}
}
/**
* 发生错误调用
* @param session
* @param error
*/
@OnError
public void onError(Session session,Throwable error) {
log.error("用户错误:" + this.sid + ",原因:" + error.getMessage());
error.printStackTrace();
}
/**
* 实现服务器主动推送
*
* @param message
* @throws IOException
*/
private void sendMessage(String message) throws IOException {
// TODO Auto-generated method stub
this.session.getBasicRemote().sendText(message);
}
/**
* 在线人数减一
*/
private static synchronized void subOnlineCount() {
// TODO Auto-generated method stub
WebSocketServer.onlineCount.getAndDecrement();
}
private static synchronized AtomicInteger getOnlineCount() {
// TODO Auto-generated method stub
return onlineCount;
}
/**
* 在线人数加一
*/
public static synchronized void addOnlineCount() {
WebSocketServer.onlineCount.getAndIncrement();
}
}
测试