原文链接:https://www.dubby.cn/detail.html?id=9104
获取代码 https://github.com/dubby1994/netty-study/tree/master
1.依赖
<!-- Netty all-in-one artifact; supplies the io.netty classes used by the server code in this article. -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.28.Final</version>
</dependency>
2.实现
WebSocketServer.java
:
/**
 * Entry point that boots the Netty server and blocks until the listening channel closes.
 */
public final class WebSocketServer {
    /** TCP port the server binds to. */
    static final int PORT = 8080;

    /**
     * Binds the server to {@link #PORT} and waits for the server channel to close.
     * Both event loop groups are shut down gracefully on the way out, whether the
     * server stopped normally or an exception was thrown.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if binding or waiting on the channel fails or is interrupted
     */
    public static void main(String[] args) throws Exception {
        // A single thread accepts connections; the default-sized group handles their I/O.
        EventLoopGroup acceptGroup = new NioEventLoopGroup(1);
        EventLoopGroup ioGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(acceptGroup, ioGroup);
            bootstrap.channel(NioServerSocketChannel.class);
            bootstrap.childHandler(new WebSocketServerInitializer());

            Channel serverChannel = bootstrap.bind(PORT).sync().channel();
            // Block the main thread until the listening socket is closed.
            serverChannel.closeFuture().sync();
        } finally {
            acceptGroup.shutdownGracefully();
            ioGroup.shutdownGracefully();
        }
    }
}
WebSocketServerInitializer.java
:
/**
 * Builds the handler pipeline for every accepted connection:
 * HTTP codec, then HTTP message aggregation, then the WebSocket handler.
 */
public class WebSocketServerInitializer extends ChannelInitializer<SocketChannel> {
    /**
     * Installs the handlers on the new channel's pipeline, in order.
     *
     * @param ch the freshly accepted client channel
     * @throws Exception declared by the overridden method; nothing here throws explicitly
     */
    @Override
    public void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline()
                .addLast(new HttpServerCodec())
                // Aggregate HTTP message parts into full messages, up to 65536 bytes of content.
                .addLast(new HttpObjectAggregator(65536))
                .addLast(new WebSocketServerHandler());
    }
}
WebSocketServerHandler.java
:
/**
 * Performs the WebSocket upgrade handshake for a connection and then processes its frames.
 *
 * <p>Text frames are broadcast to every client that has completed the handshake; binary
 * frames are echoed back to the sender; ping frames are answered with a pong carrying the
 * same payload; close frames complete the closing handshake.
 */
public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
    /**
     * Channels that have completed the WebSocket handshake. A channel is added only after
     * a successful handshake, so broadcasts never reach a connection that is still speaking
     * plain HTTP. Closed channels are removed by the group itself.
     */
    private static final ChannelGroup channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    private static final String WEBSOCKET_PATH = "/websocket";
    /** Maximum accepted WebSocket frame payload, in bytes. */
    private static final int MAX_FRAME_PAYLOAD_LENGTH = 5 * 1024 * 1024;

    /** Handshaker chosen for this connection; set once the upgrade request has been seen. */
    private WebSocketServerHandshaker handshaker;

    /**
     * Dispatches an inbound message to the HTTP or WebSocket handling path.
     * Messages of any other type are ignored.
     *
     * @param ctx the handler context of this connection
     * @param msg the inbound message
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof FullHttpRequest) {
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /**
     * Flushes anything written with {@code ctx.write(...)} during the current read batch.
     *
     * @param ctx the handler context of this connection
     */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) {
        ctx.flush();
    }

    /**
     * Validates the HTTP request and, if it is a WebSocket upgrade, performs the handshake.
     * Requests that are malformed, not GET, or not an upgrade are answered with an HTTP
     * error response instead.
     *
     * @param ctx the handler context of this connection
     * @param req the aggregated HTTP request
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
        // Handle a bad request.
        if (!req.decoderResult().isSuccess()) {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
            return;
        }
        // Allow only GET methods. Compare by value rather than by reference.
        if (!GET.equals(req.method())) {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
            return;
        }
        // A plain HTTP GET is not a handshake. Answer it with an error response instead of
        // letting the handshaker throw and the connection be dropped without any reply.
        if (!req.headers().contains(HttpHeaderNames.UPGRADE, "websocket", true)) {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
            return;
        }
        // Handshake
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                getWebSocketLocation(req), null, true, MAX_FRAME_PAYLOAD_LENGTH);
        handshaker = wsFactory.newHandshaker(req);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }
        handshaker.handshake(ctx.channel(), req).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) {
                if (future.isSuccess()) {
                    // Only now can the channel send WebSocket frames, so it may join the broadcast group.
                    channels.add(future.channel());
                } else {
                    // Surface the failure through the pipeline so exceptionCaught handles it.
                    future.channel().pipeline().fireExceptionCaught(future.cause());
                }
            }
        });
    }

    /**
     * Processes a single WebSocket frame according to its type.
     * Frame types not listed in the class description are ignored.
     *
     * @param ctx   the handler context of this connection
     * @param frame the received frame; retained where it is passed on, because the
     *              superclass releases it after {@link #channelRead0} returns
     */
    private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        if (frame instanceof CloseWebSocketFrame) {
            handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
            return;
        }
        if (frame instanceof PingWebSocketFrame) {
            ctx.write(new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        if (frame instanceof TextWebSocketFrame) {
            // Decode the payload once and reuse it for logging and broadcasting.
            String text = ((TextWebSocketFrame) frame).text();
            // NOTE(review): ctx.name() is this handler's name in the pipeline, not a
            // per-client identifier. Confirm whether the channel id was intended.
            System.out.println(ctx.name() + "\t" + text);
            // Broadcast to all connected clients.
            channels.writeAndFlush(new TextWebSocketFrame(text));
            return;
        }
        if (frame instanceof BinaryWebSocketFrame) {
            // Echo binary frames back to the sender only.
            ctx.write(frame.retain());
        }
    }

    /**
     * Writes an HTTP response, filling the body with the status text for non-200 responses,
     * and closes the connection when the request was not keep-alive or the status is not 200.
     *
     * @param ctx the handler context of this connection
     * @param req the request being answered (consulted for keep-alive)
     * @param res the response to send
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
        boolean ok = res.status().code() == HttpResponseStatus.OK.code();
        if (!ok) {
            ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
            try {
                res.content().writeBytes(buf);
            } finally {
                // Release the temporary buffer even if the copy fails.
                buf.release();
            }
            HttpUtil.setContentLength(res, res.content().readableBytes());
        }
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!HttpUtil.isKeepAlive(req) || !ok) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Prints the stack trace of the failure and closes the connection.
     *
     * @param ctx   the handler context of this connection
     * @param cause the exception raised in the pipeline
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        ctx.close();
    }

    /**
     * Builds the {@code ws://} URL passed to the handshaker factory from the request's
     * Host header and {@link #WEBSOCKET_PATH}.
     *
     * @param req the upgrade request
     * @return the WebSocket location URL
     */
    private static String getWebSocketLocation(FullHttpRequest req) {
        String location = req.headers().get(HttpHeaderNames.HOST) + WEBSOCKET_PATH;
        return "ws://" + location;
    }
}
这里你也可以自己用一个集合来维护channel,再用自己的线程池来执行广播;不过Netty已经替我们考虑到了这一点,提供了ChannelGroup来维护channel。