利用 WebSocket 的子协议(Sec-WebSocket-Protocol 请求头)携带 token,在网关层进行动态鉴权
new WebSocket(url,[token])
- 重写默认的 WebsocketRoutingFilter(直接拷贝下面路径中的源码,再加入鉴权逻辑)
spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java
- 重写的完整代码
import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.PRESERVE_HOST_HEADER_ATTRIBUTE;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.containsEncodedParts;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.isAlreadyRouted;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.setAlreadyRouted;
import static org.springframework.util.StringUtils.commaDelimitedListToStringArray;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import com.voicecomm.common.core.constant.GatewayConstant;
import com.voicecomm.common.core.utils.JwtToken;
import com.voicecomm.common.core.utils.UserRedisUtils;
import reactor.core.publisher.Mono;
/**
 * Global gateway filter that proxies WebSocket upgrade requests to the routed
 * target and, before proxying, authorizes the request using the token the
 * client sends in the header named by
 * {@code GatewayConstant.WebSocket.SEC_WEBSOCKET_PROTOCOL}.
 *
 * <p>Requests whose target scheme is not {@code ws}/{@code wss}, or that were
 * already routed, are passed down the chain untouched. Requests that fail the
 * authorization check are completed with {@code 401 Unauthorized}.
 *
 * @author Spencer Gibb
 * @author Nikita Konev
 */
@Component
public class GlobalWebsocketRoutingFilter implements GlobalFilter, Ordered {

    private static final Log log = LogFactory.getLog(GlobalWebsocketRoutingFilter.class);

    private final WebSocketClient webSocketClient;

    private final WebSocketService webSocketService;

    private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;

    // do not use this headersFilters directly, use getHeadersFilters() instead.
    private volatile List<HttpHeadersFilter> headersFilters;

    @Autowired
    private UserRedisUtils userRedisUtils;

    /**
     * @param webSocketClient client used to open the connection to the routed target
     * @param webSocketService service that performs the server-side handshake with the caller
     * @param headersFiltersProvider provider of the header filters applied to the proxied request
     */
    public GlobalWebsocketRoutingFilter(WebSocketClient webSocketClient, WebSocketService webSocketService,
            ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) {
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
        this.headersFiltersProvider = headersFiltersProvider;
    }

    /**
     * Maps {@code http} to {@code ws} and {@code https} to {@code wss}; any other
     * scheme is returned lower-cased and otherwise unchanged. Package-private for testing.
     *
     * @param scheme the URI scheme to convert
     * @return the WebSocket equivalent of the scheme, or the lower-cased input
     */
    static String convertHttpToWs(String scheme) {
        scheme = scheme.toLowerCase();
        return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wss" : scheme;
    }

    @Override
    public int getOrder() {
        // Before NettyRoutingFilter since this routes certain http requests
        return Ordered.LOWEST_PRECEDENCE - 2;
    }

    /**
     * Proxies the exchange as a WebSocket session when the routed URL uses the
     * {@code ws}/{@code wss} scheme and the request passes
     * {@link #authorizationWebsocket(ServerWebExchange)}.
     *
     * @param exchange the current server exchange
     * @param chain the remaining filter chain, used when this filter does not apply
     * @return completion signal of the proxied session, of the 401 response, or of the chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        changeSchemeIfIsWebSocketUpgrade(exchange);

        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme();

        if (isAlreadyRouted(exchange) || (!"ws".equals(scheme) && !"wss".equals(scheme))) {
            return chain.filter(exchange);
        }

        // Reject before marking the exchange as routed or touching the target.
        if (!authorizationWebsocket(exchange)) {
            exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED);
            return exchange.getResponse().setComplete();
        }
        setAlreadyRouted(exchange);

        HttpHeaders filtered = filterRequest(getHeadersFilters(), exchange);

        // Each header value may itself be a comma-separated list; flatten and trim.
        List<String> protocols = exchange.getRequest().getHeaders()
                .get(GatewayConstant.WebSocket.SEC_WEBSOCKET_PROTOCOL);
        if (protocols != null) {
            protocols = protocols.stream()
                    .flatMap(header -> Arrays.stream(commaDelimitedListToStringArray(header)))
                    .map(String::trim)
                    .collect(Collectors.toList());
        }

        return this.webSocketService.handleRequest(exchange,
                new ProxyWebSocketHandler(requestUrl, this.webSocketClient, filtered, protocols));
    }

    /**
     * Returns the header filters applied to the proxied request: the filters
     * supplied by the provider followed by two WebSocket-specific ones (Host
     * handling and removal of {@code sec-websocket*} headers).
     *
     * <p>The list is built completely in a local variable and only then stored in
     * the volatile field, so a concurrent caller never observes a list that is
     * missing the two extra filters. The provider's list is copied rather than
     * mutated in place. Package-private for testing.
     *
     * @return the lazily created list of header filters
     */
    /* for testing */ List<HttpHeadersFilter> getHeadersFilters() {
        List<HttpHeadersFilter> filters = this.headersFilters;
        if (filters == null) {
            filters = new ArrayList<>(this.headersFiltersProvider.getIfAvailable(ArrayList::new));

            // remove host header unless specifically asked not to
            filters.add((headers, exchange) -> {
                HttpHeaders filtered = new HttpHeaders();
                filtered.addAll(headers);
                filtered.remove(HttpHeaders.HOST);
                boolean preserveHost = exchange.getAttributeOrDefault(PRESERVE_HOST_HEADER_ATTRIBUTE, false);
                if (preserveHost) {
                    String host = exchange.getRequest().getHeaders().getFirst(HttpHeaders.HOST);
                    filtered.add(HttpHeaders.HOST, host);
                }
                return filtered;
            });

            // drop every sec-websocket* header from the proxied request
            filters.add((headers, exchange) -> {
                HttpHeaders filtered = new HttpHeaders();
                headers.entrySet().stream()
                        .filter(entry -> !entry.getKey().toLowerCase().startsWith("sec-websocket"))
                        .forEach(header -> filtered.addAll(header.getKey(), header.getValue()));
                return filtered;
            });

            // Publish only after the list is complete.
            this.headersFilters = filters;
        }
        return filters;
    }

    /**
     * Rewrites the routed URL's scheme from {@code http}/{@code https} to
     * {@code ws}/{@code wss} when the request carries an {@code Upgrade: WebSocket}
     * header (compared case-insensitively). Otherwise the exchange is left unchanged.
     *
     * @param exchange the current server exchange
     */
    static void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) {
        // Check the Upgrade
        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme().toLowerCase();
        String upgrade = exchange.getRequest().getHeaders().getUpgrade();
        // change the scheme if the socket client send a "http" or "https"
        if ("WebSocket".equalsIgnoreCase(upgrade) && ("http".equals(scheme) || "https".equals(scheme))) {
            String wsScheme = convertHttpToWs(scheme);
            boolean encoded = containsEncodedParts(requestUrl);
            URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme).build(encoded).toUri();
            exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl);
            if (log.isTraceEnabled()) {
                log.trace("changeSchemeTo:[" + wsRequestUrl + "]");
            }
        }
    }

    /**
     * Handler for the caller-facing session that opens a second session to the
     * target URL and pipes messages between the two.
     */
    private static class ProxyWebSocketHandler implements WebSocketHandler {

        private final WebSocketClient client;

        private final URI url;

        private final HttpHeaders headers;

        private final List<String> subProtocols;

        ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers, List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            this.subProtocols = protocols;
        }

        @Override
        public List<String> getSubProtocols() {
            return this.subProtocols;
        }

        @Override
        public Mono<Void> handle(WebSocketSession session) {
            // pass headers along so custom headers can be sent through
            return client.execute(url, this.headers, new WebSocketHandler() {
                @Override
                public Mono<Void> handle(WebSocketSession proxySession) {
                    // When the target closes, close the caller-facing session with the same status.
                    // The reverse direction (caller close -> target close) is not wired here.
                    Mono<Void> serverClose = proxySession.closeStatus().flatMap(session::close);
                    // Use retain() for Reactor Netty
                    Mono<Void> proxySessionSend = proxySession
                            .send(session.receive().doOnNext(WebSocketMessage::retain));
                    Mono<Void> serverSessionSend = session
                            .send(proxySession.receive().doOnNext(WebSocketMessage::retain));
                    return Mono.zip(proxySessionSend, serverSessionSend, serverClose).then();
                }

                /**
                 * Copy subProtocols so they are available downstream.
                 *
                 * @return the sub-protocols of the enclosing handler
                 */
                @Override
                public List<String> getSubProtocols() {
                    return ProxyWebSocketHandler.this.subProtocols;
                }
            });
        }
    }

    /**
     * Authorizes a WebSocket request using the token carried as the sub-protocol.
     *
     * <p>The first value of the sub-protocol header is used verbatim as the token.
     * The request is accepted only when the token is non-blank,
     * {@code JwtToken.getPrimaryKey(token)} yields a non-blank user id, and
     * {@code userRedisUtils.get(userId, token)} returns a non-null value.
     *
     * <p>NOTE(review): the Redis lookup looks like a synchronous call made on the
     * request path - confirm it does not block the event loop.
     *
     * @param exchange the current server exchange
     * @return {@code true} when the request is authorized
     */
    private boolean authorizationWebsocket(ServerWebExchange exchange) {
        ServerHttpRequest request = exchange.getRequest();
        String token = request.getHeaders().getFirst(GatewayConstant.WebSocket.SEC_WEBSOCKET_PROTOCOL);
        if (StringUtils.isBlank(token)) {
            return false;
        }
        String userId = JwtToken.getPrimaryKey(token);
        if (StringUtils.isBlank(userId)) {
            return false;
        }
        return userRedisUtils.get(userId, token) != null;
    }
}
- 其他服务如何配合这种动态子协议(握手时需把客户端传入的子协议作为服务端支持的子协议),参考代码如下
/**
 * Netty inbound handler for the initial HTTP request of a WebSocket connection.
 *
 * <p>It validates the upgrade request, performs the WebSocket handshake using the
 * value of the {@code Sec-WebSocket-Protocol} header as the supported
 * sub-protocol, stores the user id derived from that value on the channel, and
 * registers the channel with {@code NoticeChannelManager}.
 */
@Slf4j
@ChannelHandler.Sharable
public class HttpHandler extends SimpleChannelInboundHandler<FullHttpRequest> {

    /**
     * Handles the upgrade request and then forwards it to the next handler.
     *
     * @param ctx the channel handler context
     * @param msg the aggregated HTTP request
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
        // msg is always a FullHttpRequest here (guaranteed by the type parameter),
        // so no instanceof check or fallback branch is needed.
        log.info("http请求,第一次握手:协议升级");
        handleHttpRequest(ctx, msg);
        log.info("------------------" + NoticeChannelManager.count());
        // SimpleChannelInboundHandler releases msg once this method returns,
        // so it must be retained before being handed to the next handler.
        ctx.fireChannelRead(msg.retain());
    }

    /**
     * Validates the request and performs the WebSocket handshake.
     *
     * <p>A request that failed decoding or whose {@code Upgrade} header is not
     * {@code websocket} (case-insensitive) is answered with {@code 400 Bad Request}.
     *
     * @param ctx the channel handler context
     * @param req the HTTP upgrade request
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
        // Compare the Upgrade token case-insensitively; clients may send "WebSocket".
        boolean isUpgrade = "websocket".equalsIgnoreCase(req.headers().get("Upgrade"));
        if (!req.decoderResult().isSuccess() || !isUpgrade) {
            log.warn("// 错误方式连接,拒绝,返回");
            sendHttpResponse(ctx, req,
                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST));
            return;
        }

        // The client-supplied sub-protocol (the token) is declared as the supported
        // sub-protocol so that the handshake response echoes it back.
        // NOTE(review): token is null when the header is absent - confirm that
        // JwtToken.getPrimaryKey tolerates a null argument.
        String token = req.headers().get("Sec-WebSocket-Protocol");
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory("ws:/" + ctx.channel(), token,
                false);
        ctx.channel().attr(NoticeChannelAttr.USER_ID).set(JwtToken.getPrimaryKey(token));

        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
        if (handshaker == null) {
            // Unsupported WebSocket version: do not register the channel.
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }
        // Register only channels for which a handshake is actually attempted.
        NoticeChannelManager.add(ctx.channel());
        handshaker.handshake(ctx.channel(), req);
    }
}
通过netty实现websocket