一 序
接上一篇《websocket入门系列:二Tomcat实现》。
书上第11章有一个WebSocket的demo，我在此基础上修改了实现。Netty的版本是4.1.7.Final。
二 server端
package com.netty.websocket;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.stream.ChunkedWriteHandler;
/**
 * Netty-based WebSocket server that keeps a registry of every open child channel so
 * that {@link WebSocketServerHandler} can broadcast text frames to all clients.
 */
public class WebSocketServer {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketServer.class);

    /** Maximum size in bytes of an HTTP request aggregated by the HttpObjectAggregator. */
    private static final int MAX_CONTENT_LENGTH = 65536;

    /** Port used by {@link #main(String[])} when no valid port argument is supplied. */
    private static final int DEFAULT_PORT = 8080;

    /**
     * All currently open child channels, keyed by channel id. A channel is added when it is
     * initialized and removed when its close future completes. The map is shared with every
     * handler instance (which run on different event-loop threads), hence the concurrent map.
     */
    private final Map<ChannelId, Channel> channelMap = new ConcurrentHashMap<ChannelId, Channel>();

    // ------------------------ member fields -----------------------
    /** Address to bind; {@code null} means the wildcard address (all interfaces). */
    private final String host;
    /** Port to bind. */
    private final int port;

    /**
     * @param host address to bind, or {@code null} to listen on all interfaces
     * @param port port to bind
     */
    public WebSocketServer(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Binds the server, blocks until the server channel is closed, and finally shuts down
     * both event loop groups.
     *
     * @throws Exception if binding fails or the calling thread is interrupted while waiting
     */
    public void run() throws Exception {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Remember this channel so text frames can be broadcast to it.
                            channelMap.put(ch.id(), ch);
                            ch.closeFuture().addListener(new ChannelFutureListener() {
                                @Override
                                public void operationComplete(ChannelFuture future) throws Exception {
                                    logger.info("channel close {}", future.channel());
                                    // Drop the reference once the channel is closed.
                                    channelMap.remove(future.channel().id());
                                }
                            });
                            pipeline.addLast("http-codec", new HttpServerCodec());
                            pipeline.addLast("aggregator", new HttpObjectAggregator(MAX_CONTENT_LENGTH));
                            pipeline.addLast("http-chunked", new ChunkedWriteHandler());
                            pipeline.addLast("handler", new WebSocketServerHandler(channelMap));
                        }
                    });
            // Honour the configured bind address; fall back to all interfaces when none is set.
            ChannelFuture bindFuture = (host == null) ? b.bind(port) : b.bind(host, port);
            Channel ch = bindFuture.sync().channel();
            System.out.println("Web socket server started at port " + port
                    + '.');
            System.out
                    .println("Open your browser and navigate to http://localhost:"
                            + port + '/');
            ch.closeFuture().sync();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

    /**
     * Starts a server bound to 127.0.0.1.
     *
     * @param args optional first argument: the port to listen on (defaults to 8080)
     */
    public static void main(String[] args) throws Exception {
        int port = DEFAULT_PORT;
        if (args.length > 0) {
            try {
                port = Integer.parseInt(args[0]);
            } catch (NumberFormatException e) {
                // Keep the default port instead of aborting, but make the problem visible.
                logger.warn("Invalid port argument '{}', using default port {}", args[0], DEFAULT_PORT, e);
            }
        }
        new WebSocketServer("127.0.0.1", port).run();
    }
}
对应的Handler（WebSocketServerHandler）：
package com.netty.websocket;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.AttributeKey;
/**
 * Handles the HTTP upgrade handshake and the subsequent WebSocket frames of one child
 * channel. Text frames are broadcast to every channel in the shared map that has been
 * upgraded to WebSocket. Messages are released automatically by
 * SimpleChannelInboundHandler after channelRead0 returns, so any content reused beyond
 * that point must be retained first.
 */
public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketServerHandler.class);
    /** Token that must appear (case-insensitively) in the Upgrade header. */
    private static final String WEBSOCKET_UPGRADE = "websocket";
    /** Token that must appear (case-insensitively) in the Connection header. */
    private static final String WEBSOCKET_CONNECTION = "Upgrade";
    /** WebSocket location used when the handshake request carries no Host header. */
    private static final String WEBSOCKET_URI_ROOT = "ws://127.0.0.1:8080";
    /**
     * Channel attribute holding the handshaker. It is set right after the handshake response
     * is issued, so its presence marks a channel that has been upgraded to WebSocket; it is
     * also needed later to close the connection properly.
     */
    private static final AttributeKey<WebSocketServerHandshaker> ATTR_HANDSHAKER = AttributeKey.newInstance("ATTR_KEY_CHANNELID");

    /**
     * All open connections, shared between handler instances. It is read from several
     * event-loop threads, so it must be safe for concurrent access.
     */
    private final Map<ChannelId, Channel> channelMap;

    /**
     * @param channelMap shared map of open channels that text frames are broadcast to
     */
    public WebSocketServerHandler(Map<ChannelId, Channel> channelMap) {
        this.channelMap = channelMap;
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        ctx.flush();
    }

    /**
     * Performs the WebSocket handshake for upgrade requests. Any other HTTP request is
     * answered with 400 Bad Request and the connection is closed, instead of being left
     * hanging without a response.
     */
    private void handleHttpRequest(ChannelHandlerContext ctx,
            FullHttpRequest req) throws Exception {
        if (!isWebSocketUpgrade(req)) {
            logger.info("rejecting normal http request {} {}", req.method(), req.uri());
            sendBadRequestAndClose(ctx);
            return;
        }
        logger.info("upgrade to websocket protocol");
        String subProtocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
        WebSocketServerHandshakerFactory factory =
                new WebSocketServerHandshakerFactory(getWebSocketLocation(req), subProtocols, false);
        WebSocketServerHandshaker handshaker = factory.newHandshaker(req);
        if (handshaker == null) {
            // The requested WebSocket version is not supported: the factory returned no handshaker.
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }
        handshaker.handshake(ctx.channel(), req);
        // Attach the handshaker so the close frame can be answered later and so the
        // broadcast loop can recognise this channel as upgraded.
        ctx.channel().attr(ATTR_HANDSHAKER).set(handshaker);
    }

    /** Builds the ws:// location from the Host header, falling back to WEBSOCKET_URI_ROOT. */
    private static String getWebSocketLocation(FullHttpRequest req) {
        String host = req.headers().get(HttpHeaderNames.HOST);
        if (host == null || host.isEmpty()) {
            return WEBSOCKET_URI_ROOT;
        }
        return "ws://" + host + req.uri();
    }

    /** Sends an empty 400 response (the HTTP codec is still in the pipeline) and closes. */
    private static void sendBadRequestAndClose(ChannelHandlerContext ctx) {
        FullHttpResponse res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST);
        res.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0);
        ctx.writeAndFlush(res).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * Dispatches a WebSocket frame by type: text is broadcast, ping is answered with pong,
     * pong is ignored, close is completed via the handshaker; anything else is logged.
     */
    private void handleWebSocketFrame(ChannelHandlerContext ctx,
            WebSocketFrame frame) {
        if (frame instanceof TextWebSocketFrame) {
            String text = ((TextWebSocketFrame) frame).text();
            logger.info("recieve TextWebSocketFrame from channel {}", ctx.channel());
            for (Channel ch : channelMap.values()) {
                // Channels that never completed the handshake have no WebSocket encoder,
                // so a frame written to them could not be encoded; skip them.
                if (ch.attr(ATTR_HANDSHAKER).get() == null) {
                    continue;
                }
                // A fresh frame per recipient: writeAndFlush releases the frame it is given.
                ch.writeAndFlush(new TextWebSocketFrame(text));
                logger.info("write text[{}] to channel {}", text, ch);
            }
            return;
        }
        if (frame instanceof PingWebSocketFrame) {
            logger.info("recieve PingWebSocketFrame from channel {}", ctx.channel());
            // retain(): the incoming frame is released after channelRead0 returns.
            ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        if (frame instanceof PongWebSocketFrame) {
            logger.info("recieve PongWebSocketFrame from channel {}", ctx.channel());
            return;
        }
        if (frame instanceof CloseWebSocketFrame) {
            logger.info("recieve CloseWebSocketFrame from channel {}", ctx.channel());
            WebSocketServerHandshaker handshaker = ctx.channel().attr(ATTR_HANDSHAKER).get();
            if (handshaker == null) {
                logger.error("channel {} have no HandShaker", ctx.channel());
                return;
            }
            // retain(): handshaker.close() writes the frame after this handler releases it.
            handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
            return;
        }
        // Remaining frame types (binary / continuation) are not supported by this demo.
        logger.warn("unhandle binary frame from channel {}", ctx.channel());
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        logger.error("exception on channel {}", ctx.channel(), cause);
        ctx.close();
    }

    /** Routes HTTP requests to the handshake logic and WebSocket frames to the frame logic. */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
        logger.info("receive msg:{}", msg);
        if (msg instanceof FullHttpRequest) {
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /**
     * Returns true for a GET request whose Upgrade header contains "websocket" and whose
     * Connection header contains "Upgrade" (both case-insensitive). Missing headers yield
     * false instead of a NullPointerException.
     */
    private boolean isWebSocketUpgrade(FullHttpRequest req) {
        HttpHeaders headers = req.headers();
        return HttpMethod.GET.equals(req.method())
                && headerContainsIgnoreCase(headers, HttpHeaderNames.UPGRADE, WEBSOCKET_UPGRADE)
                && headerContainsIgnoreCase(headers, HttpHeaderNames.CONNECTION, WEBSOCKET_CONNECTION);
    }

    /** True if the header is present and its value contains token, ignoring case. */
    private static boolean headerContainsIgnoreCase(HttpHeaders headers, CharSequence name, String token) {
        String value = headers.get(name);
        if (value == null) {
            return false;
        }
        int lastStart = value.length() - token.length();
        for (int i = 0; i <= lastStart; i++) {
            if (value.regionMatches(true, i, token, 0, token.length())) {
                return true;
            }
        }
        return false;
    }
}
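这里是主要业务逻辑的实现：按照WebSocket协议识别出握手请求并完成握手，然后根据收到的数据帧类型（文本、Ping、Pong、Close）分别进行处理。
里面有一些细节，需要注意这个异常：
io.netty.util.IllegalReferenceCountException: refCnt: 0
原因是SimpleChannelInboundHandler在channelRead0返回后会自动release收到的消息。回复Pong时复用了Ping帧的content，把Close帧交给handshaker.close()时也复用了该帧，这两处都必须先调用retain()。否则帧的引用计数已降为0，之后写出时就会抛出上面的异常。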
三 测试
测试时需要修改一下地址：
web页面由Tomcat提供，端口为80；为避免端口冲突，WebSocket服务的端口改为8080。
web页面代码不再重复贴出，参见上一篇。
分别打开Firefox和Chrome浏览器，再用Java client模拟发送一句消息：
package com.websocket.client;
import java.io.IOException;
import java.net.URI;
import java.util.concurrent.CountDownLatch;
import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
public class Test {
    /**
     * Demo entry point. Opens a WebSocket session to the local server using a
     * {@code MyClient} endpoint and sends one text message. It then parks the main thread
     * indefinitely, so that messages broadcast by the server keep being delivered to the
     * endpoint.
     */
    public static void main(String[] args) throws DeploymentException, IOException, InterruptedException {
        final WebSocketContainer container = ContainerProvider.getWebSocketContainer();
        final String serverUrl = "ws://127.0.0.1:8080";
        final Session session = container.connectToServer(new MyClient(), URI.create(serverUrl));

        final int turn = 0;
        final String greeting = "client send: " + turn;
        session.getBasicRemote().sendText(greeting);

        Thread.sleep(1000);
        // Nothing ever counts this latch down, so await() blocks until the thread is interrupted.
        final CountDownLatch neverReleased = new CountDownLatch(1);
        neverReleased.await();
    }
}
Java的client输出:
I was accpeted by her!
客户端收到消息: client send: 0
客户端收到消息: chrome
客户端收到消息: hi,ff
页面的输出截屏:
*****************************************************************************
这里只是参照书上的例子，简单实现了一个群发消息的demo。
还得深入学习Netty，不然出了异常很难去修复。
参考:
http://lixiaohui.iteye.com/blog/2328183