场景
接到一个需求,需要使用webSocket和客户端交互。
需要保证连接安全,最初的想法是在webSocket握手时把access_token携带在请求头中,服务端在握手前进行鉴权处理,如果不合法则拒绝握手,合法则握手连接。
后面与客户端沟通,无法在握手连接时传递自定义的请求头,但可以传递一个Sec-WebSocket-Protocol 的请求头,可以将token放入这个头中。
在实际使用时,发现如果握手时携带了Sec-WebSocket-Protocol 那么握手成功后的响应头也必须携带,不然客户端会判定握手失败。
这就引出了问题,因为我在netty提供的WebSocketServerProtocolHandler 中没有找到在握手连接后增加响应头的处理。
下面来看一下WebSocketServerProtocolHandler 实现握手的源码
一,WebSocketServerProtocolHandler 实现握手源码
在WebSocketServerProtocolHandler中有一个handlerAdded方法
/**
 * Prepares the pipeline when this handler is added to it: makes sure the handshake handler and, if
 * enabled in the decoder configuration, the UTF-8 frame validator sit directly before this handler.
 * Each of them is only added when the pipeline does not already contain a handler of that class.
 *
 * @param ctx context of this handler; supplies the pipeline and the name used as insertion anchor
 */
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
ChannelPipeline cp = ctx.pipeline();
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
// Add the WebSocketHandshakeHandler before this one.
cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
new WebSocketServerProtocolHandshakeHandler(serverConfig));
}
if (serverConfig.decoderConfig().withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) {
// Add the UTF-8 checking before this one.
cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator(serverConfig.decoderConfig().closeOnProtocolViolation()));
}
}
在该方法中,在WebSocketServerProtocolHandler 之前添加了两个前置处理器
WebSocketServerProtocolHandshakeHandler / Utf8FrameValidator
主要是WebSocketServerProtocolHandshakeHandler ,该处理器就是用来处理webSocket握手连接的处理器。
继续去追踪该类的channelRead()方法看一下是如何处理握手连接的
/**
 * Handles inbound HTTP objects and performs the WebSocket handshake for requests on the WebSocket path.
 * <p>
 * An {@link HttpRequest} that does not target the WebSocket path is passed to the next handler.
 * For a request on the WebSocket path a handshaker is created, this handler removes itself from the
 * pipeline, and the handshake is started; its outcome is reported through {@code handshakePromise}
 * and through user events / {@code exceptionCaught}. Objects that are not an {@link HttpRequest} are
 * forwarded or released depending on the path decision made for the most recent request.
 *
 * @param ctx context of this handler
 * @param msg the inbound message; it is cast to {@link HttpObject} without a type check
 * @throws Exception declared by the overridden method
 */
@Override
public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
final HttpObject httpObject = (HttpObject) msg;
if (httpObject instanceof HttpRequest) {
final HttpRequest req = (HttpRequest) httpObject;
// Stored in a field so the branches at the bottom can reuse the decision for later messages.
isWebSocketPath = isWebSocketPath(req);
// If the request is not for the WebSocket path, hand it to the next handler.
if (!isWebSocketPath) {
ctx.fireChannelRead(msg);
return;
}
try {
final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
serverConfig.subprotocols(), serverConfig.decoderConfig());
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
// Captured locally so the listener below completes the promise that was current at this point.
final ChannelPromise localHandshakePromise = handshakePromise;
if (handshaker == null) {
// No handshaker could be created for this request: answer with the "unsupported version" response.
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
} else {
// Ensure we set the handshaker and replace this handler before we
// trigger the actual handshake. Otherwise we may receive websocket bytes in this handler
// before we had a chance to replace it.
//
// See https://github.com/netty/netty/issues/9471.
WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
ctx.pipeline().remove(this);
// The core handshake call. Note that no custom response headers are passed here.
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
handshakeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (!future.isSuccess()) {
// Handshake failed: fail the promise and report the cause to the following handlers.
localHandshakePromise.tryFailure(future.cause());
ctx.fireExceptionCaught(future.cause());
} else {
// Handshake succeeded: complete the promise and announce it via user events.
localHandshakePromise.trySuccess();
// Kept for compatibility
ctx.fireUserEventTriggered(
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
ctx.fireUserEventTriggered(
new WebSocketServerProtocolHandler.HandshakeComplete(
req.uri(), req.headers(), handshaker.selectedSubprotocol()));
}
}
});
// NOTE(review): helper not shown here; presumably arms a timeout for the handshake promise.
applyHandshakeTimeout();
}
} finally {
// The request is consumed by this handler, so it is released on every path through the try block.
ReferenceCountUtil.release(req);
}
} else if (!isWebSocketPath) {
// Not an HttpRequest, and the last request was not for the WebSocket path: pass it on.
ctx.fireChannelRead(msg);
} else {
// Not an HttpRequest, and the last request was a WebSocket handshake: drop the message.
ReferenceCountUtil.release(msg);
}
}
在上述代码中可以看到handshaker.handshake(ctx.channel(), req); 该方法就是webSocket握手连接的处理,我们再追进去看一下(该调用最终会进入下面这个带HttpHeaders参数的重载方法)
/**
 * Performs the server side of the opening handshake: builds the handshake response, swaps the HTTP
 * handlers in the pipeline for WebSocket frame handlers, and writes the response to the channel.
 *
 * @param channel         channel on which the handshake takes place
 * @param req             the HTTP request that asked for the upgrade
 * @param responseHeaders headers handed to {@code newHandshakeResponse} together with the request
 * @param promise         completed once the response has been written, or failed if the pipeline
 *                        contains neither an {@code HttpRequestDecoder} nor an {@code HttpServerCodec},
 *                        or if the write fails
 * @return the given {@code promise}
 */
public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
HttpHeaders responseHeaders, final ChannelPromise promise) {
if (logger.isDebugEnabled()) {
logger.debug("{} WebSocket version {} server handshake", channel, version());
}
// The response is built from the request plus the caller-supplied headers.
FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
ChannelPipeline p = channel.pipeline();
// Remove the HTTP aggregator and compressor if they are present in the pipeline.
if (p.get(HttpObjectAggregator.class) != null) {
p.remove(HttpObjectAggregator.class);
}
if (p.get(HttpContentCompressor.class) != null) {
p.remove(HttpContentCompressor.class);
}
ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
// Name of the handler that is removed after the handshake response has been written.
final String encoderName;
if (ctx == null) {
// this means the user use an HttpServerCodec
ctx = p.context(HttpServerCodec.class);
if (ctx == null) {
promise.setFailure(
new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
// The response is never written on this path, so it is released here.
response.release();
return promise;
}
// The codec stays in place for now; it is removed by name once the response has been written.
p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
encoderName = ctx.name();
} else {
// Separate decoder/encoder: replace the HTTP decoder right away, keep the HTTP encoder for the response.
p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
encoderName = p.context(HttpResponseEncoder.class).name();
p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
}
channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
// The response has been written; remove the handler recorded in encoderName.
ChannelPipeline p = future.channel().pipeline();
p.remove(encoderName);
promise.setSuccess();
} else {
promise.setFailure(future.cause());
}
}
});
return promise;
}
可以看到实际的握手方法是可以携带HttpHeaders的,并且也是通过httpHeader构建Response, 握手成功后响应给了客户端。
也就是说我们只需要调用带参(HttpHeader)的方法去进行握手处理即可完成该功能。
但是很可惜,netty提供的WebSocketServerProtocolHandshakeHandler 在握手时并没有提供扩展点,并且该类不是public的,无法通过继承重写握手处理来解决。也就是说,需要自己去实现http升级到websocket,完成握手连接操作。
自实现webSocket握手的handler
package com.snr.angel.remoting.handler;
import cn.hutool.core.util.StrUtil;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.*;
/**
 * Inbound handler that performs the HTTP-to-WebSocket upgrade itself, so that the handshake response
 * can carry a {@code Sec-WebSocket-Protocol} header echoing the value sent by the client.
 * <p>
 * Outcome per request:
 * <ul>
 *   <li>not a WebSocket upgrade request: forwarded to the next handler unchanged;</li>
 *   <li>upgrade request for a URI other than {@code /test}: {@code 404 Not Found}, connection closed;</li>
 *   <li>WebSocket version not supported: the "unsupported version" response
 *       ({@code 426 Upgrade Required});</li>
 *   <li>{@code Sec-WebSocket-Protocol} missing or not the expected value: {@code 403 Forbidden},
 *       connection closed;</li>
 *   <li>otherwise: the handshake is performed and this handler is replaced by a
 *       {@link WebSocketFrameAggregator} registered as {@code "websocketHandler"}.</li>
 * </ul>
 * Created 2023/9/6 09:30.
 *
 * @author xzq
 * @version 1.0.0
 */
public class WebSocketServerProtocolHandshakeHandler extends SimpleChannelInboundHandler<FullHttpRequest> {

    /** The only request URI for which an upgrade is accepted. */
    private static final String WEBSOCKET_PATH = "/test";

    /** Value the client has to send in {@code Sec-WebSocket-Protocol} for the handshake to be accepted. */
    private static final String EXPECTED_PROTOCOL = "xzq_test";

    /**
     * Decides what to do with an aggregated HTTP request; see the class description for the outcomes.
     *
     * @param ctx context of this handler
     * @param req the aggregated HTTP request; released by the superclass after this method returns
     * @throws Exception declared by the overridden method
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception {
        if (!isUpgradeRequest(req)) {
            // Not a WebSocket handshake: let the following handlers deal with it.
            // retain() balances the automatic release done by SimpleChannelInboundHandler.
            ctx.fireChannelRead(req.retain());
            return;
        }
        if (!WEBSOCKET_PATH.equals(req.uri())) {
            // A wrong path is not a version problem, so do not answer with 426.
            rejectAndClose(ctx, HttpResponseStatus.NOT_FOUND);
            return;
        }

        // Build the handshaker for the WebSocket version requested by the client.
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                getWebSocketLocation(req), null, true);
        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
        if (handshaker == null) {
            // Unsupported WebSocket version: reply with 426 Upgrade Required.
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }

        String protocol = req.headers().get("Sec-WebSocket-Protocol");
        if (!EXPECTED_PROTOCOL.equals(protocol)) {
            // Missing or unexpected token: refuse the handshake instead of reporting a version error.
            rejectAndClose(ctx, HttpResponseStatus.FORBIDDEN);
            return;
        }

        // Echo the protocol back; the client treats the handshake as failed if the header is missing.
        HttpHeaders headers = new DefaultHttpHeaders();
        headers.set("Sec-WebSocket-Protocol", protocol);
        // Upgrade to WebSocket; a failed handshake write is reported to the pipeline's exceptionCaught.
        handshaker.handshake(ctx.channel(), req, headers, ctx.newPromise())
                .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
        // The upgrade is a one-time step: swap this handler for the frame aggregator.
        ctx.pipeline().replace(this, "websocketHandler", new WebSocketFrameAggregator(65536));
    }

    /**
     * Registers a close listener on the channel when a {@code HANDSHAKE_COMPLETE} event arrives, then
     * passes the event on.
     * <p>
     * NOTE(review): this class does not fire that event itself and replaces itself after a successful
     * handshake, so the branch only runs if another handler in the pipeline emits the event while this
     * handler is still installed.
     *
     * @param ctx context of this handler
     * @param evt the user event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
            // Handshake finished: report when the connection is closed later on.
            ctx.channel().closeFuture().addListener(future -> {
                if (future.isSuccess()) {
                    System.out.println("WebSocket connection closed.");
                }
            });
        }
        super.userEventTriggered(ctx, evt);
    }

    /**
     * Tells whether the request was decoded successfully and asks for an upgrade to WebSocket.
     * The {@code Upgrade} value is compared case-insensitively, as the WebSocket protocol requires.
     */
    private static boolean isUpgradeRequest(FullHttpRequest req) {
        return req.decoderResult().isSuccess()
                && "websocket".equalsIgnoreCase(req.headers().get("Upgrade"));
    }

    /** Sends an empty response with the given status and closes the connection once it is written. */
    private static void rejectAndClose(ChannelHandlerContext ctx, HttpResponseStatus status) {
        FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
        response.headers().set("Content-Length", "0");
        ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
    }

    /** Builds the WebSocket URL handed to the handshaker factory from the request's {@code Host} header. */
    private String getWebSocketLocation(FullHttpRequest req) {
        return "ws://" + req.headers().get("Host") + WEBSOCKET_PATH;
    }
}