netty实现webSocket握手处理并添加响应头

场景

接到一个需求,需要使用webSocket和客户端交互。
需要保证连接安全,最初的想法是在webSocket握手时把access_token携带在请求头中,服务端在握手前进行鉴权处理,如果不合法则拒绝握手,合法则握手连接。

后面与客户端沟通,无法在握手连接时传递自定义的请求头,但可以传递一个Sec-WebSocket-Protocol 的请求头,可以将token放入这个头中。
在实际使用时,发现如果握手时携带了Sec-WebSocket-Protocol 那么握手成功后的响应头也必须携带,不然客户端会判定握手失败。

这就引出了问题,因为我在netty提供的WebSocketServerProtocolHandler 中没有找到提供握手连接后增加相应头的处理。

下面来看一下WebSocketServerProtocolHandler 实现握手的源码

一,WebSocketServerProtocolHandler 实现握手源码

在WebSocketServerProtocolHandler中有一个handlerAdded方法

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
            // Add the WebSocketHandshakeHandler before this one.
            cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
                    new WebSocketServerProtocolHandshakeHandler(serverConfig));
        }
        if (serverConfig.decoderConfig().withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) {
            // Add the UFT8 checking before this one.
            cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
                    new Utf8FrameValidator(serverConfig.decoderConfig().closeOnProtocolViolation()));
        }
    }

在该方法中,在WebSocketServerProtocolHandler 之前添加了两个前置处理器
WebSocketServerProtocolHandshakeHandler / Utf8FrameValidator
主要是WebSocketServerProtocolHandshakeHandler ,该处理器就是用来处理webSocket握手连接的处理器。
继续去追踪该类的channelRead()方法看一下是如何处理握手连接的

    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        final HttpObject httpObject = (HttpObject) msg;

        if (httpObject instanceof HttpRequest) {
            final HttpRequest req = (HttpRequest) httpObject;
            isWebSocketPath = isWebSocketPath(req);
            // 判断是否为webSocket path,不是交友下一个处理器处理
            if (!isWebSocketPath) {
                ctx.fireChannelRead(msg);
                return;
            }

            try {
                final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                        getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
                        serverConfig.subprotocols(), serverConfig.decoderConfig());
                final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
                final ChannelPromise localHandshakePromise = handshakePromise;
                if (handshaker == null) {
                    WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
                } else {
                    // Ensure we set the handshaker and replace this handler before we
                    // trigger the actual handshake. Otherwise we may receive websocket bytes in this handler
                    // before we had a chance to replace it.
                    //
                    // See https://github.com/netty/netty/issues/9471.
                    WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
                    ctx.pipeline().remove(this);
					//核心的握手方法
                    final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
                    handshakeFuture.addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(ChannelFuture future) {
                            if (!future.isSuccess()) {
                                localHandshakePromise.tryFailure(future.cause());
                                ctx.fireExceptionCaught(future.cause());
                            } else {
                                localHandshakePromise.trySuccess();
                                // Kept for compatibility
                                ctx.fireUserEventTriggered(
                                        WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                                ctx.fireUserEventTriggered(
                                        new WebSocketServerProtocolHandler.HandshakeComplete(
                                                req.uri(), req.headers(), handshaker.selectedSubprotocol()));
                            }
                        }
                    });
                    applyHandshakeTimeout();
                }
            } finally {
                ReferenceCountUtil.release(req);
            }
        } else if (!isWebSocketPath) {
            ctx.fireChannelRead(msg);
        } else {
            ReferenceCountUtil.release(msg);
        }
    }

在上述代码中可以看到handshaker.handshake(ctx.channel(), req); 该方法就是webSocket握手连接的处理,我们在追进去看一下

    public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
                                            HttpHeaders responseHeaders, final ChannelPromise promise) {

        if (logger.isDebugEnabled()) {
            logger.debug("{} WebSocket version {} server handshake", channel, version());
        }
        FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
        ChannelPipeline p = channel.pipeline();
        if (p.get(HttpObjectAggregator.class) != null) {
            p.remove(HttpObjectAggregator.class);
        }
        if (p.get(HttpContentCompressor.class) != null) {
            p.remove(HttpContentCompressor.class);
        }
        ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
        final String encoderName;
        if (ctx == null) {
            // this means the user use an HttpServerCodec
            ctx = p.context(HttpServerCodec.class);
            if (ctx == null) {
                promise.setFailure(
                        new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
                response.release();
                return promise;
            }
            p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
            p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
            encoderName = ctx.name();
        } else {
            p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());

            encoderName = p.context(HttpResponseEncoder.class).name();
            p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
        }
        channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    ChannelPipeline p = future.channel().pipeline();
                    p.remove(encoderName);
                    promise.setSuccess();
                } else {
                    promise.setFailure(future.cause());
                }
            }
        });
        return promise;
    }

可以看到实际的握手方法是可以携带HttpHeaders的,并且也是通过httpHeader构建Response, 握手成功后响应给了客户端。
也就是说我们只需要调用带参(HttpHeader)的方法去进行握手处理即可完成该功能。
但是很可惜netty提供的WebSocketServerProtocolHandshakeHandler 握手并没有提供一些扩展,并且该类不是public的,无法做到继承重写握手的处理来解决,也就是说需要自己去实现http升级到websocket,完成握手连接操作

自实现webSocket握手的handler

package com.snr.angel.remoting.handler;

import cn.hutool.core.util.StrUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.*;

/**
 * @Author xzq
 * @Version 1.0.0
 * @Date 2023/9/6 09:30
 */
public class WebSocketServerProtocolHandshakeHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
    private static final String WEBSOCKET_PATH = "/test";
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception {
        // 检查HTTP请求是否是WebSocket握手请求
        if (req.decoderResult().isSuccess() && "websocket".equals(req.headers().get("Upgrade"))) {


            if (!req.uri().equals(WEBSOCKET_PATH)) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
                return;
            }

            // 构建WebSocket握手处理器
            WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(req), null, true);
            WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);

            if (handshaker == null) {
                // 如果不支持WebSocket版本,返回HTTP 405错误
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                String protocol = req.headers().get("Sec-WebSocket-Protocol");

                if (StrUtil.isNotEmpty(protocol) && protocol.equals("xzq_test")) {
                    // 创建一个 HttpHeaders 实例
                    HttpHeaders headers = new DefaultHttpHeaders();
                    // 设置单个头部信息
                    headers.set("Sec-WebSocket-Protocol", protocol);

                    // 握手成功,升级协议为WebSocket
                    handshaker.handshake(ctx.channel(), req, headers, ctx.newPromise());

                    // 添加WebSocket消息处理器
                    ctx.pipeline().replace(this, "websocketHandler", new WebSocketFrameAggregator(65536));
                }else{
                    // 如果不支持WebSocket版本,返回HTTP 405错误
                    WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
                }

            }
        } else {
            // 非WebSocket握手请求,关闭连接
            ctx.fireChannelRead(req.retain());
            ctx.channel().pipeline().fireUserEventTriggered(new CloseWebSocketFrame(1000, "Unsupported protocol"));
        }
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
            // WebSocket握手完成,添加连接关闭的监听器
            ctx.channel().closeFuture().addListener(future -> {
                if (future.isSuccess()) {
                    // 连接成功关闭
                    System.out.println("WebSocket connection closed.");
                }
            });
        }

        super.userEventTriggered(ctx, evt);
    }
    private String getWebSocketLocation(FullHttpRequest req) {
        return "ws://" + req.headers().get("Host") + WEBSOCKET_PATH;
    }
}

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值