1、第一步需要在pom.xml中添加 websocket 依赖
<!-- Spring Boot WebSocket starter: supplies the javax.websocket (JSR-356) API used by
     WebSocketServer below, plus Spring's ServerEndpointExporter. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2、第二步:添加 ServerEndpointExporter Bean 配置(仅在使用 Spring Boot 内嵌容器时需要;若打成 war 包部署到外部 Tomcat 等容器,由容器自行扫描 @ServerEndpoint,此时不要再声明该 Bean)
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
/**
 * Spring configuration that exposes a {@link ServerEndpointExporter} bean.
 *
 * <p>Per Spring's Javadoc, the exporter registers {@code @ServerEndpoint}-annotated classes
 * (such as {@code WebSocketServer}) with the standard Java WebSocket runtime of the embedded
 * servlet container.
 */
@Configuration
public class WebSocketConfig {

    /**
     * Declares the exporter bean (bean name {@code serverEndpointExporter}).
     *
     * @return a fresh exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
3、第三步新建 WebSocketServer 对外服务
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import org.springframework.stereotype.Component;
import com.alibaba.fastjson.JSONObject;
import com.dbappsecurity.utils.concurrent.ThreadPoolUtil;
import com.dbappsecurity.utils.domain.WebSocketMsgDTO;
import lombok.extern.slf4j.Slf4j;
/**
 * WebSocket endpoint reachable at {@code /websocket/{userId}}.
 *
 * <p>Every open connection is recorded in {@link #map} under the {@code userId} from the URL, so
 * server-side code can push a message to one user ({@link #sendMessage}) or to all connected users
 * ({@link #sendMessageAll}). Pushes run on {@link ThreadPoolUtil} so the caller is not blocked by
 * network I/O.
 *
 * <p>NOTE(review): the per-instance {@code session} field assumes the container creates one
 * endpoint instance per connection (the JSR-356 default for annotated endpoints) — confirm if a
 * custom configurator is ever introduced.
 *
 * <p>A session's basic remote endpoint must not be written by two threads at once (JSR-356 allows
 * it to throw {@link IllegalStateException}), so every write goes through {@link #sendText(String)},
 * which serializes writes per connection.
 */
@ServerEndpoint(value = "/websocket/{userId}")
@Component
@Slf4j
public class WebSocketServer {

    /** Number of open connections: incremented in onOpen, decremented in onClose. */
    private static final AtomicLong count = new AtomicLong(0);

    /**
     * Open connections keyed by userId. A second connection with the same userId replaces the
     * earlier entry; the earlier connection stays open but can no longer be addressed here.
     */
    public static ConcurrentHashMap<Long, WebSocketServer> map = new ConcurrentHashMap<>();

    /** Session of this connection; used to send data to the client. */
    private volatile Session session;

    /** The userId this connection registered under in onOpen; null before that. */
    private volatile Long userId;

    /** Serializes writes to this connection's basic remote endpoint. */
    private final Object sendLock = new Object();

    /**
     * Called when a connection is established: registers it under {@code userId}, bumps the
     * online counter and greets the client with {@code "Connected."}.
     *
     * @param userId  user id taken from the request path
     * @param session the new WebSocket session
     */
    @OnOpen
    public void onOpen(@PathParam("userId") Long userId, Session session) {
        // TODO: verify that userId belongs to a logged-in user; if not, treat it as an intrusion attempt
        this.session = session;
        this.userId = userId;
        map.put(userId, this);
        addOnlineCount();
        log.info("有新连接加入: [{}], 当前在线人数为 [{}]", userId, getOnlineCount());
        try {
            sendText("Connected.");
        } catch (IOException e) {
            log.error("websocket IO异常", e);
        }
    }

    /**
     * Called when the connection closes: unregisters it and decrements the online counter.
     */
    @OnClose
    public void onClose() {
        Long id = this.userId;
        if (id != null) {
            // Conditional remove: only drop the entry if it still points at THIS connection, so a
            // superseded connection closing does not evict the same user's newer connection.
            map.remove(id, this);
        }
        // Decrement the online count by 1.
        subOnlineCount();
        log.info("有一连接关闭 : [{}], 当前在线人数为: [{}]" , id, getOnlineCount());
    }

    /**
     * Called when a message arrives from the client; currently it is only logged.
     *
     * @param message the message sent by the client
     * @param session the session it arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("来自客户端的消息:{}", message);
    }

    /**
     * Called when an error occurs on the connection; the error is logged.
     *
     * @param session the affected session
     * @param error   the error raised
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("发生错误", error);
    }

    /**
     * Asynchronously sends {@code msg}, serialized as JSON, to the connection registered under
     * {@code msg.getUserId()}. Does nothing if that user has no registered connection. Send
     * failures are logged, not thrown.
     *
     * @param msg the message to push
     */
    public static void sendMessage(final WebSocketMsgDTO msg) {
        ThreadPoolUtil.execute(new Runnable() {
            @Override
            public void run() {
                WebSocketServer target = map.get(msg.getUserId());
                if (target == null) {
                    // No connection registered for this user.
                    return;
                }
                try {
                    target.sendText(JSONObject.toJSONString(msg));
                } catch (IOException | IllegalStateException e) {
                    log.error("推送至 websocket 失败:", e);
                }
            }
        });
    }

    /**
     * Asynchronously sends {@code socketMsgDTO}, serialized as JSON, to every registered
     * connection. A failure on one connection is logged and does not stop delivery to the others.
     *
     * @param socketMsgDTO the message to broadcast
     */
    public static void sendMessageAll(final WebSocketMsgDTO socketMsgDTO) {
        ThreadPoolUtil.execute(new Runnable() {
            @Override
            public void run() {
                // Serialize once: the payload is identical for every recipient.
                String payload = JSONObject.toJSONString(socketMsgDTO);
                // ConcurrentHashMap iteration is weakly consistent; concurrent opens/closes are safe.
                for (WebSocketServer target : map.values()) {
                    try {
                        target.sendText(payload);
                    } catch (IOException | IllegalStateException e) {
                        log.error("推送至 websocket 失败:", e);
                    }
                }
            }
        });
    }

    /**
     * Sends a text message to this connection's client, serialized with any other send on the
     * same connection. Silently skips the send if the session is missing or already closed.
     *
     * @param text the text to send
     * @throws IOException if the underlying send fails
     */
    private void sendText(String text) throws IOException {
        synchronized (sendLock) {
            Session current = this.session;
            if (current != null && current.isOpen()) {
                current.getBasicRemote().sendText(text);
            }
        }
    }

    /** Returns the current number of open connections. */
    public static long getOnlineCount() {
        return count.get();
    }

    /** Increments the open-connection counter (atomic; no extra locking required). */
    public static void addOnlineCount() {
        count.incrementAndGet();
    }

    /** Decrements the open-connection counter (atomic; no extra locking required). */
    public static void subOnlineCount() {
        count.decrementAndGet();
    }
}
4、为了避免推送消息时的网络 I/O 阻塞调用方,笔者在此处使用线程池 ThreadPoolUtil 异步发送消息
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import lombok.extern.slf4j.Slf4j;
/**
 * Thread-pool utility: one process-wide {@link ThreadPoolExecutor} exposed through static methods.
 *
 * <p>The executor is created exactly once, either when Spring instantiates this bean or lazily on
 * the first static call, whichever comes first. When the Spring context closes ({@link #destroy()}),
 * the pool stops accepting tasks, waits up to {@code timeout} for running/queued tasks, and then
 * interrupts whatever is still running.
 *
 * <p>Note: the work queue is an unbounded {@link LinkedBlockingQueue}. With such a queue a
 * {@link ThreadPoolExecutor} never creates more than {@code corePoolSize} threads, so
 * {@code maximumPoolSize} and {@code keepAliveTime} have no practical effect here.
 *
 * Created by babylon on 2016/12/4.
 */
@Component
@Slf4j
public class ThreadPoolUtil implements InitializingBean, DisposableBean {
    private static volatile ThreadPoolUtil threadPool;
    /** The shared executor; volatile so every thread sees it once published. */
    private static volatile ThreadPoolExecutor executor = null;

    // Basic pool parameters; in real use these could be moved to a configuration file.
    /** Number of threads kept in the pool; once reached, further tasks wait in the queue. */
    private final int corePoolSize = 5;
    /** Maximum number of threads (only reachable with a bounded queue; see class note). */
    private final int maximumPoolSize = 15;
    /** Idle time after which threads above corePoolSize are terminated. */
    private final long keepAliveTime = 10;
    /** How long destroy() waits for running and queued tasks before interrupting them. */
    private final long timeout = 10;
    /** Time unit for keepAliveTime and timeout. */
    private final TimeUnit unit = TimeUnit.SECONDS;
    /**
     * Container for tasks waiting to run (unbounded).
     *
     * Alternatives: ArrayBlockingQueue (bounded), SynchronousQueue (direct hand-off);
     * see the ThreadPoolExecutor Javadoc, section "Queuing".
     */
    private final LinkedBlockingQueue<Runnable> workQueue = new LinkedBlockingQueue<Runnable>();

    @Override
    public void afterPropertiesSet() throws Exception {
        init();
    }

    /**
     * Returns the singleton instance, creating it (and therefore the executor) if needed.
     *
     * @return the singleton instance
     */
    public static ThreadPoolUtil init() {
        if (threadPool == null) {
            synchronized (ThreadPoolUtil.class) {
                if (threadPool == null) {
                    threadPool = new ThreadPoolUtil();
                }
            }
        }
        return threadPool;
    }

    /**
     * Private constructor. It is invoked both by {@link #init()} and reflectively by Spring, so it
     * only builds the executor if none exists yet.
     */
    private ThreadPoolUtil() {
        synchronized (ThreadPoolUtil.class) {
            if (executor == null) {
                executor = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, keepAliveTime, unit,
                        workQueue, new CustomThreadFactory());
                log.info("线程池初始化成功");
            }
        }
    }

    /**
     * Returns the shared executor, creating it on demand if this class is used before Spring has
     * initialized the bean.
     *
     * @return the shared executor, never null
     */
    public static ThreadPoolExecutor getExecutor() {
        ThreadPoolExecutor current = executor;
        if (current == null) {
            init();
            current = executor;
        }
        return current;
    }

    /**
     * Submits a task for asynchronous execution.
     *
     * @param t the task (a Thread is accepted but only its run() is executed, on a pool thread)
     */
    public static void execute(Thread t) {
        getExecutor().execute(t);
    }

    /**
     * Submits a task for asynchronous execution.
     *
     * @param t the task
     */
    public static void execute(Runnable t) {
        getExecutor().execute(t);
    }

    /** Returns the number of tasks currently waiting in the queue. */
    public static int getQueueSize() {
        return getExecutor().getQueue().size();
    }

    /**
     * Submits a task asynchronously. The returned Future's get() blocks until the task completes
     * and then returns {@code null} (a Runnable has no result).
     *
     * @param t the task
     * @return a Future tracking completion of the task
     */
    public static Future<?> submit(Runnable t) {
        return getExecutor().submit(t);
    }

    /**
     * Submits a value-returning task asynchronously. The returned Future's get() yields its result.
     *
     * @param t   the task
     * @param <T> the result type
     * @return a Future holding the task's result
     */
    public static <T> Future<T> submit(Callable<T> t) {
        return getExecutor().submit(t);
    }

    /**
     * Initiates an orderly shutdown: no new tasks are accepted, already submitted ones still run.
     * Does nothing if the executor was never created.
     */
    public static void shutdown() {
        ThreadPoolExecutor current = executor;
        if (current != null) {
            current.shutdown();
        }
    }

    /**
     * Blocks until all tasks have finished after {@link #shutdown()}, the timeout elapses, or the
     * current thread is interrupted, whichever happens first. Logs a warning on timeout.
     *
     * @param timeout maximum time to wait
     * @param unit    unit of {@code timeout}
     * @throws InterruptedException if interrupted while waiting
     */
    public static void awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        log.info("Thread pool ,awaitTermination started, please wait till all the jobs complete.");
        ThreadPoolExecutor current = executor;
        if (current != null && !current.awaitTermination(timeout, unit)) {
            log.warn("Thread pool did not terminate within {} {}", timeout, unit);
        }
    }

    /**
     * Spring shutdown callback: stops accepting tasks and waits up to {@code timeout} for pending
     * work. Still-running tasks are then interrupted.
     */
    @Override
    public void destroy() throws Exception {
        shutdown();
        ThreadPoolExecutor current = executor;
        if (current == null) {
            return;
        }
        try {
            if (!current.awaitTermination(timeout, unit)) {
                log.warn("Thread pool did not terminate within {} {}, forcing shutdown", timeout, unit);
                current.shutdownNow();
            }
        } catch (InterruptedException e) {
            current.shutdownNow();
            // Preserve the interrupt status for the caller.
            Thread.currentThread().interrupt();
        }
    }

    /** Names pool threads "ThreadPoolUtil1", "ThreadPoolUtil2", ... */
    private static class CustomThreadFactory implements ThreadFactory {
        private final AtomicInteger count = new AtomicInteger(0);

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r);
            t.setName(ThreadPoolUtil.class.getSimpleName() + count.incrementAndGet());
            return t;
        }
    }
}
5、前端测试页面
<!DOCTYPE HTML>
<html>
<head>
<title>My WebSocket</title>
</head>
<body>
Welcome<br/>
<input id="text" type="text" /><button onclick="send()">Send</button> <button onclick="closeWebSocket()">Close</button>
<div id="message">
</div>
<script type="text/javascript">
var websocket = null;
// Open the connection only if the browser supports WebSocket.
if ('WebSocket' in window) {
    websocket = new WebSocket("ws://localhost:8088/websocket/1");
} else {
    alert('Not support websocket');
}

// Wire up handlers only when a socket exists; otherwise these assignments would throw on null.
if (websocket !== null) {
    // Called when a connection error occurs.
    websocket.onerror = function () {
        setMessageInnerHTML("error");
    };
    // Called once the connection is established.
    websocket.onopen = function (event) {
        setMessageInnerHTML("open");
    };
    // Called for every message received from the server.
    websocket.onmessage = function (event) {
        setMessageInnerHTML(event.data);
    };
    // Called when the connection closes.
    websocket.onclose = function () {
        setMessageInnerHTML("close");
    };
    // Close the socket explicitly when the window closes, so the server does not see the
    // connection drop abruptly (which makes it raise an exception).
    window.onbeforeunload = function () {
        websocket.close();
    };
}

// Append one line of text to the message area. Text nodes are used instead of innerHTML so that
// data coming from the server cannot inject markup or scripts into the page.
function setMessageInnerHTML(text) {
    var box = document.getElementById('message');
    box.appendChild(document.createTextNode(text));
    box.appendChild(document.createElement('br'));
}

// Close the connection.
function closeWebSocket() {
    if (websocket !== null) {
        websocket.close();
    }
}

// Send the text box contents; WebSocket.send throws if the socket is not open, so check first.
function send() {
    if (websocket === null || websocket.readyState !== WebSocket.OPEN) {
        setMessageInnerHTML("not connected");
        return;
    }
    var message = document.getElementById('text').value;
    websocket.send(message);
}
</script>
</body>
</html>
运行效果截图: