Spring Cloud Gateway 中对 WebSocket 进行动态鉴权

利用 WebSocket 的子协议请求头(Sec-WebSocket-Protocol)携带 token,进行动态鉴权

  1. 前端建立连接时把 token 作为子协议传入:new WebSocket(url, [token])
  2. 重写(直接拷贝后修改)默认的
spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java
  3. 重写后的完整代码:
import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.PRESERVE_HOST_HEADER_ATTRIBUTE;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.containsEncodedParts;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.isAlreadyRouted;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.setAlreadyRouted;
import static org.springframework.util.StringUtils.commaDelimitedListToStringArray;

import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;

import com.voicecomm.common.core.constant.GatewayConstant;
import com.voicecomm.common.core.utils.JwtToken;
import com.voicecomm.common.core.utils.UserRedisUtils;

import reactor.core.publisher.Mono;

/**
 * Global filter that proxies WebSocket upgrade requests to the routed backend.
 * <p>
 * This is a copy of Spring Cloud Gateway's {@code WebsocketRoutingFilter} with one addition. Before the
 * handshake is proxied, the token carried as the first entry of the {@code Sec-WebSocket-Protocol} request
 * header is validated. If validation fails, the request is answered with {@code 401 Unauthorized}.
 *
 * @author Spencer Gibb
 * @author Nikita Konev
 */
@Component
public class GlobalWebsocketRoutingFilter implements GlobalFilter, Ordered {

	private static final Log log = LogFactory.getLog(GlobalWebsocketRoutingFilter.class);

	/** Case-insensitive prefix of the client handshake headers that are not forwarded to the backend. */
	private static final String SEC_WEBSOCKET_HEADER_PREFIX = "sec-websocket";

	private final WebSocketClient webSocketClient;

	private final WebSocketService webSocketService;

	private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;

	// do not use this headersFilters directly, use getHeadersFilters() instead.
	private volatile List<HttpHeadersFilter> headersFilters;

	// Queried with (userId, token) during authorization; a null result rejects the connection.
	@Autowired
	private UserRedisUtils userRedisUtils;

	/**
	 * @param webSocketClient client used to open the backend WebSocket connection
	 * @param webSocketService service that performs the handshake with the calling client
	 * @param headersFiltersProvider source of the header filters applied before forwarding
	 */
	public GlobalWebsocketRoutingFilter(WebSocketClient webSocketClient, WebSocketService webSocketService,
			ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) {
		this.webSocketClient = webSocketClient;
		this.webSocketService = webSocketService;
		this.headersFiltersProvider = headersFiltersProvider;
	}

	/* for testing */
	static String convertHttpToWs(String scheme) {
		scheme = scheme.toLowerCase();
		return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wss" : scheme;
	}

	@Override
	public int getOrder() {
		// Before NettyRoutingFilter since this routes certain http requests
		return Ordered.LOWEST_PRECEDENCE - 2;
	}

	/**
	 * Proxies an authenticated WebSocket upgrade request to the backend.
	 * <p>
	 * Requests that are already routed, or whose scheme is not {@code ws}/{@code wss}, continue down the chain.
	 *
	 * @param exchange the current exchange
	 * @param chain the remaining filter chain
	 * @return completion of the proxied session, of the {@code 401} response, or of the rest of the chain
	 */
	@Override
	public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
		changeSchemeIfIsWebSocketUpgrade(exchange);

		URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
		String scheme = requestUrl.getScheme();

		if (isAlreadyRouted(exchange) || (!"ws".equals(scheme) && !"wss".equals(scheme))) {
			return chain.filter(exchange);
		}

		if (!authorizationWebsocket(exchange)) {
			exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED);
			return exchange.getResponse().setComplete();
		}

		setAlreadyRouted(exchange);

		HttpHeaders filtered = filterRequest(getHeadersFilters(), exchange);
		// getHeadersFilters() strips the Sec-WebSocket-* headers. The requested subprotocols, including the
		// token, are instead exposed to both handshakes through ProxyWebSocketHandler#getSubProtocols().
		List<String> protocols = getRequestedProtocols(exchange.getRequest().getHeaders());
		return this.webSocketService.handleRequest(exchange,
				new ProxyWebSocketHandler(requestUrl, this.webSocketClient, filtered, protocols));
	}

	/**
	 * Returns the header filters applied to the forwarded handshake, building them on first use.
	 * <p>
	 * The list holds the provider's filters plus two extras: a host-header filter and a filter that strips the
	 * client's Sec-WebSocket-* headers.
	 *
	 * @return the header filters to apply before forwarding
	 */
	/* for testing */ List<HttpHeadersFilter> getHeadersFilters() {
		List<HttpHeadersFilter> filters = this.headersFilters;
		if (filters == null) {
			// Populate a private copy and publish it to the volatile field only once it is complete.
			// Publishing first would let concurrent requests iterate the list while filters are still being
			// appended. The copy also tolerates an unmodifiable list coming from the provider.
			filters = new ArrayList<>(this.headersFiltersProvider.getIfAvailable(ArrayList::new));

			// remove host header unless specifically asked not to
			filters.add((headers, exchange) -> {
				HttpHeaders filtered = new HttpHeaders();
				filtered.addAll(headers);
				filtered.remove(HttpHeaders.HOST);
				boolean preserveHost = exchange.getAttributeOrDefault(PRESERVE_HOST_HEADER_ATTRIBUTE, false);
				if (preserveHost) {
					String host = exchange.getRequest().getHeaders().getFirst(HttpHeaders.HOST);
					filtered.add(HttpHeaders.HOST, host);
				}
				return filtered;
			});

			// Drop every header whose name starts with "sec-websocket" in any letter case. regionMatches avoids
			// the locale-sensitive toLowerCase().
			filters.add((headers, exchange) -> {
				HttpHeaders filtered = new HttpHeaders();
				headers.entrySet().stream()
						.filter(entry -> !entry.getKey().regionMatches(true, 0, SEC_WEBSOCKET_HEADER_PREFIX, 0,
								SEC_WEBSOCKET_HEADER_PREFIX.length()))
						.forEach(header -> filtered.addAll(header.getKey(), header.getValue()));
				return filtered;
			});

			this.headersFilters = filters;
		}
		return filters;
	}

	/**
	 * Rewrites the routed URL's scheme from {@code http}/{@code https} to {@code ws}/{@code wss}.
	 * <p>
	 * The rewrite happens only when the request carries {@code Upgrade: WebSocket}, compared case-insensitively.
	 *
	 * @param exchange the exchange whose {@code GATEWAY_REQUEST_URL_ATTR} may be replaced
	 */
	static void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) {
		// Check the Upgrade
		URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
		String scheme = requestUrl.getScheme().toLowerCase();
		String upgrade = exchange.getRequest().getHeaders().getUpgrade();
		// change the scheme if the socket client send a "http" or "https"
		if ("WebSocket".equalsIgnoreCase(upgrade) && ("http".equals(scheme) || "https".equals(scheme))) {
			String wsScheme = convertHttpToWs(scheme);
			boolean encoded = containsEncodedParts(requestUrl);
			URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme).build(encoded).toUri();
			exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl);
			if (log.isTraceEnabled()) {
				log.trace("changeSchemeTo:[" + wsRequestUrl + "]");
			}
		}
	}

	/**
	 * Parses every {@code Sec-WebSocket-Protocol} header value into a flat list of trimmed subprotocol names.
	 *
	 * @param headers the handshake request headers
	 * @return the requested subprotocols in request order; empty if the header is absent
	 */
	private static List<String> getRequestedProtocols(HttpHeaders headers) {
		List<String> values = headers.get(GatewayConstant.WebSocket.SEC_WEBSOCKET_PROTOCOL);
		if (values == null) {
			return new ArrayList<>();
		}
		return values.stream().flatMap(header -> Arrays.stream(commaDelimitedListToStringArray(header)))
				.map(String::trim).collect(Collectors.toList());
	}

	/**
	 * Handshake handler that opens a backend WebSocket session and relays messages in both directions.
	 */
	private static class ProxyWebSocketHandler implements WebSocketHandler {

		private final WebSocketClient client;

		private final URI url;

		private final HttpHeaders headers;

		private final List<String> subProtocols;

		ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers, List<String> protocols) {
			this.client = client;
			this.url = url;
			this.headers = headers;
			this.subProtocols = protocols;
		}

		@Override
		public List<String> getSubProtocols() {
			return this.subProtocols;
		}

		@Override
		public Mono<Void> handle(WebSocketSession session) {
			// pass headers along so custom headers can be sent through
			return client.execute(url, this.headers, new WebSocketHandler() {
				@Override
				public Mono<Void> handle(WebSocketSession proxySession) {
					// When the backend closes, close the client session with the same status.
					Mono<Void> serverClose = proxySession.closeStatus().flatMap(session::close);
					// NOTE(review): propagating the client's close to the backend is disabled in this copy.
					// The disabled line was:
					// Mono<Void> proxyClose = session.closeStatus().flatMap(proxySession::close);
					// Confirm this is intended.
					// Use retain() for Reactor Netty
					Mono<Void> proxySessionSend = proxySession
							.send(session.receive().doOnNext(WebSocketMessage::retain));
					Mono<Void> serverSessionSend = session
							.send(proxySession.receive().doOnNext(WebSocketMessage::retain));
					return Mono.zip(proxySessionSend, serverSessionSend, serverClose).then();
				}

				/**
				 * Copy subProtocols so they are available downstream.
				 *
				 * @return the subprotocols requested by the calling client
				 */
				@Override
				public List<String> getSubProtocols() {
					return ProxyWebSocketHandler.this.subProtocols;
				}
			});
		}

	}

	/**
	 * Authenticates a WebSocket handshake using the token sent as a subprotocol.
	 * <p>
	 * The token is the first entry of {@code Sec-WebSocket-Protocol}. The connection is accepted only if:
	 * the token is non-blank, {@code JwtToken.getPrimaryKey} yields a non-blank user id, and
	 * {@code userRedisUtils.get(userId, token)} is non-null.
	 *
	 * @param exchange the handshake exchange
	 * @return {@code true} if the connection may be proxied
	 */
	private boolean authorizationWebsocket(ServerWebExchange exchange) {
		ServerHttpRequest request = exchange.getRequest();
		List<String> protocols = getRequestedProtocols(request.getHeaders());
		// Several requested subprotocols arrive as one comma-separated list (RFC 6455). Only the first entry is
		// the token.
		String token = protocols.isEmpty() ? null : protocols.get(0);
		if (StringUtils.isBlank(token)) {
			return false;
		}
		String userId;
		try {
			userId = JwtToken.getPrimaryKey(token);
		}
		catch (RuntimeException ex) {
			// A token the JWT helper cannot parse is a client error: answer 401 rather than failing the exchange.
			if (log.isDebugEnabled()) {
				log.debug("Rejecting WebSocket handshake: token could not be parsed", ex);
			}
			return false;
		}
		if (StringUtils.isBlank(userId)) {
			return false;
		}
		// NOTE(review): this is presumably a blocking Redis call on the reactive request thread. Consider a
		// reactive client or a bounded-elastic scheduler; confirm against UserRedisUtils.
		return userRedisUtils.get(userId, token) != null;
	}

}

  4. 下游服务如何配合动态子协议完成握手,代码参考:
/**
 * Netty handler for the HTTP upgrade request that opens a WebSocket connection.
 * <p>
 * The value of {@code Sec-WebSocket-Protocol} (the token) is passed to the handshaker factory as the supported
 * subprotocol. The user id parsed from the token is stored on the channel before the channel is registered
 * with {@code NoticeChannelManager}.
 */
@Slf4j
@ChannelHandler.Sharable
public class HttpHandler extends SimpleChannelInboundHandler<FullHttpRequest> {

	/**
	 * Performs the WebSocket handshake for the incoming request, then forwards the request down the pipeline.
	 */
	@Override
	public void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
		// SimpleChannelInboundHandler<FullHttpRequest> only dispatches FullHttpRequest instances here.
		log.info("http请求,第一次握手:协议升级");
		handleHttpRequest(ctx, msg);
		log.info("------------------{}", NoticeChannelManager.count());
		// SimpleChannelInboundHandler releases msg when this method returns; retain() keeps it alive for the
		// next handler.
		ctx.fireChannelRead(msg.retain());
	}

	/**
	 * Validates the upgrade request and starts the WebSocket handshake.
	 * <p>
	 * Responses by case:
	 * <ul>
	 * <li>Failed decoding or a non-WebSocket upgrade: {@code 400 Bad Request}.</li>
	 * <li>Unsupported WebSocket version: Netty's unsupported-version response.</li>
	 * <li>A request that yields a handshaker: the channel is tagged with its user id and registered.</li>
	 * </ul>
	 */
	private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
		// The Upgrade header value is case-insensitive (RFC 6455); some clients send "WebSocket".
		if (!req.decoderResult().isSuccess() || !"websocket".equalsIgnoreCase(req.headers().get("Upgrade"))) {
			log.warn("// 错误方式连接,拒绝,返回");
			sendHttpResponse(ctx, req,
					new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST));
			return;
		}
		// Handshake setup: accept the client's requested subprotocol (the token) so the handshake can echo it.
		String token = req.headers().get("Sec-WebSocket-Protocol");
		// NOTE(review): the URL is built from Channel#toString(), not a real ws:// URL. Confirm this is intended.
		WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory("ws:/" + ctx.channel(), token,
				false);
		WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
		if (handshaker == null) {
			WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
			return;
		}
		// Tag and register the channel only once a handshake will actually be attempted, so rejected
		// connections are not tracked.
		ctx.channel().attr(NoticeChannelAttr.USER_ID).set(JwtToken.getPrimaryKey(firstProtocol(token)));
		NoticeChannelManager.add(ctx.channel());
		handshaker.handshake(ctx.channel(), req);
	}

	/**
	 * Returns the first entry of a comma-separated {@code Sec-WebSocket-Protocol} value, trimmed.
	 *
	 * @param header the raw header value, may be {@code null}
	 * @return the first subprotocol, or {@code null} if the header is absent
	 */
	private static String firstProtocol(String header) {
		if (header == null) {
			return null;
		}
		int comma = header.indexOf(',');
		return (comma < 0 ? header : header.substring(0, comma)).trim();
	}
}

(下游服务)通过 Netty 实现 WebSocket

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
Spring Cloud Gateway 本身是一个基于 WebFlux 的反向代理,因此可以借助 WebFlux 的 WebSocket 支持来代理 WebSocket。具体地,你可以编写一个自定义的 GatewayFilter 来实现 WebSocket 支持。以下是一个简单的示例:

```java
@Component
public class WebSocketGatewayFilter implements GatewayFilter {
    @Autowired
    private WebSocketHandler webSocketHandler;

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        if (isWebSocketRequest(exchange)) {
            return handleWebSocketRequest(exchange);
        }
        return chain.filter(exchange);
    }

    private boolean isWebSocketRequest(ServerWebExchange exchange) {
        HttpHeaders headers = exchange.getRequest().getHeaders();
        return headers.containsValue(HttpHeaders.UPGRADE, "websocket", true) && headers.contains(HttpHeaders.CONNECTION, "Upgrade");
    }

    private Mono<Void> handleWebSocketRequest(ServerWebExchange exchange) {
        return Mono.defer(() -> {
            ServerHttpRequest request = exchange.getRequest();
            ServerHttpResponse response = exchange.getResponse();
            HttpHeaders headers = response.getHeaders();
            headers.set(HttpHeaders.UPGRADE, HttpHeaders.UPGRADE);
            headers.set(HttpHeaders.CONNECTION, HttpHeaders.UPGRADE);
            return webSocketHandler.handleRequest(exchange, webSocketSession -> {
                // WebSocket session is established
            });
        });
    }
}
```

在这个示例中,`WebSocketGatewayFilter` 是一个自定义的 GatewayFilter。它负责检查请求是否为 WebSocket 请求,若是,则调用 `handleWebSocketRequest` 方法进行处理。在 `handleWebSocketRequest` 方法中,我们首先设置响应头,然后使用 `webSocketHandler` 处理 WebSocket 请求。`webSocketHandler` 是一个实现了 `WebSocketHandler` 接口的 Spring Bean,它会在 WebSocket 会话建立时被调用。

注意:该示例仅为示意,无法直接编译。`WebSocketHandler` 接口只提供 `handle(WebSocketSession)` 方法,并没有 `handleRequest`。握手应通过 `WebSocketService.handleRequest(exchange, handler)` 完成,正文中的 `GlobalWebsocketRoutingFilter` 正是这样做的。

你还需要在应用程序配置文件中添加以下配置来启用 WebSocket 支持:

```
spring:
  cloud:
    gateway:
      websockets:
        enabled: true
```

完成上述步骤后,你就可以使用 Spring Cloud Gateway 支持 WebSocket 了。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值