背景介绍
项目使用的技术栈是Spring Cloud,有个功能需求是:
业务上,在Spring Cloud Gateway模块的服务已经可以获取到token,并且已实现鉴权通过后从token获取到身份信息;
现在希望把身份信息,填充到request参数里面(这里把多个数据封装成一个BaseDTO对象,用于扩展)。
后续处理具体业务的微服务模块,在controller层的方法传参,只要继承了BaseDTO对象,就可以直接获取到身份信息,用于业务逻辑处理。
问题描述
简单来说,问题就是 Spring Cloud Gateway 如何动态添加请求参数。
Spring Cloud Gateway Add Request Parameter
- 查看官方文档,提供了下面的示例:
docs.spring.io/spring-clou…
但这种方式需要在配置文件中写明参数,看起来只能添加固定值,无法按请求动态设置。
- 在github上看到也有人提了类似问题,
github.com/spring-clou…
但是实现的效果也跟配置文件差不多。
- 在stackoverflow上也查了类似回答:
stackoverflow.com/questions/6…
大概思路有了方向。
解决方案
在 Spring Cloud Gateway 源码上,发现了这两个类 AddRequestParameterGatewayFilterFactory 和 ModifyRequestBodyGatewayFilterFactory
代码内容如下:
AddRequestParameterGatewayFilterFactory
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package org.springframework.cloud.gateway.filter.factory;
import java.net.URI;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractNameValueGatewayFilterFactory.NameValueConfig;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
public class AddRequestParameterGatewayFilterFactory extends AbstractNameValueGatewayFilterFactory {
public AddRequestParameterGatewayFilterFactory() {
}
public GatewayFilter apply(NameValueConfig config) {
return (exchange, chain) -> {
URI uri = exchange.getRequest().getURI();
StringBuilder query = new StringBuilder();
String originalQuery = uri.getRawQuery();
if (StringUtils.hasText(originalQuery)) {
query.append(originalQuery);
if (originalQuery.charAt(originalQuery.length() - 1) != '&') {
query.append('&');
}
}
query.append(config.getName());
query.append('=');
query.append(config.getValue());
try {
URI newUri = UriComponentsBuilder.fromUri(uri).replaceQuery(query.toString()).build(true).toUri();
ServerHttpRequest request = exchange.getRequest().mutate().uri(newUri).build();
return chain.filter(exchange.mutate().request(request).build());
} catch (RuntimeException var8) {
throw new IllegalStateException("Invalid URI query: \"" + query.toString() + "\"");
}
};
}
}
复制代码
ModifyRequestBodyGatewayFilterFactory
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package org.springframework.cloud.gateway.filter.factory.rewrite;
import java.util.Map;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.support.DefaultServerRequest;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.ServerRequest;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
public class ModifyRequestBodyGatewayFilterFactory extends AbstractGatewayFilterFactory<ModifyRequestBodyGatewayFilterFactory.Config> {
public ModifyRequestBodyGatewayFilterFactory() {
super(ModifyRequestBodyGatewayFilterFactory.Config.class);
}
/** @deprecated */
@Deprecated
public ModifyRequestBodyGatewayFilterFactory(ServerCodecConfigurer codecConfigurer) {
this();
}
public GatewayFilter apply(ModifyRequestBodyGatewayFilterFactory.Config config) {
return (exchange, chain) -> {
Class inClass = config.getInClass();
ServerRequest serverRequest = new DefaultServerRequest(exchange);
Mono<?> modifiedBody = serverRequest.bodyToMono(inClass).flatMap((o) -> {
return config.rewriteFunction.apply(exchange, o);
});
BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, config.getOutClass());
CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, exchange.getRequest().getHeaders());
return bodyInserter.insert(outputMessage, new BodyInserterContext()).then(Mono.defer(() -> {
ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
public HttpHeaders getHeaders() {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.putAll(super.getHeaders());
httpHeaders.set("Transfer-Encoding", "chunked");
return httpHeaders;
}
public Flux<DataBuffer> getBody() {
return outputMessage.getBody();
}
};
return chain.filter(exchange.mutate().request(decorator).build());
}));
};
}
public static class Config {
private Class inClass;
private Class outClass;
private Map<String, Object> inHints;
private Map<String, Object> outHints;
private RewriteFunction rewriteFunction;
public Config() {
}
public Class getInClass() {
return this.inClass;
}
public ModifyRequestBodyGatewayFilterFactory.Config setInClass(Class inClass) {
this.inClass = inClass;
return this;
}
public Class getOutClass() {
return this.outClass;
}
public ModifyRequestBodyGatewayFilterFactory.Config setOutClass(Class outClass) {
this.outClass = outClass;
return this;
}
public Map<String, Object> getInHints() {
return this.inHints;
}
public ModifyRequestBodyGatewayFilterFactory.Config setInHints(Map<String, Object> inHints) {
this.inHints = inHints;
return this;
}
public Map<String, Object> getOutHints() {
return this.outHints;
}
public ModifyRequestBodyGatewayFilterFactory.Config setOutHints(Map<String, Object> outHints) {
this.outHints = outHints;
return this;
}
public RewriteFunction getRewriteFunction() {
return this.rewriteFunction;
}
public <T, R> ModifyRequestBodyGatewayFilterFactory.Config setRewriteFunction(Class<T> inClass, Class<R> outClass, RewriteFunction<T, R> rewriteFunction) {
this.setInClass(inClass);
this.setOutClass(outClass);
this.setRewriteFunction(rewriteFunction);
return this;
}
public ModifyRequestBodyGatewayFilterFactory.Config setRewriteFunction(RewriteFunction rewriteFunction) {
this.rewriteFunction = rewriteFunction;
return this;
}
}
}
复制代码
实际上,可以当作官方提供的参考示例。
照着类似内容,我们可以依样画葫芦,在自己的网关过滤器上实现添加参数的功能。
实现代码
鉴权过滤器主要处理流程
/**
 * Global gateway filter that resolves the caller's identity from the {@code Authorization}
 * header and, for paths listed in {@code UrlConstant.addRequestParameterUrls}, injects that
 * identity into the forwarded request.
 */
@Component
public class AuthFilter implements GlobalFilter, Ordered {

    private static final Logger LOGGER = LoggerFactory.getLogger(AuthFilter.class);

    /** Matcher used to compare request paths against the configured Ant-style patterns. */
    private static final AntPathMatcher antPathMatcher = new AntPathMatcher();

    @Override
    public int getOrder() {
        return FilterOrderConstant.getOrder(this.getClass().getName());
    }

    /**
     * Authenticates the request and dispatches to the identity-injection strategy that fits
     * the HTTP method and content type.
     *
     * @param exchange the current exchange
     * @param chain the remaining filter chain
     * @return completion signal of the downstream chain, or of the rejection response
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        URI uri = request.getURI();
        String url = uri.getPath();

        // Paths that do not require authentication pass through untouched.
        if (UrlConstant.skipAuthUrls.stream().anyMatch(path -> antPathMatcher.match(path, url))) {
            return chain.filter(exchange);
        }

        // Token validation itself is not covered here; getClaim yields null when no identity is available.
        String token = request.getHeaders().getFirst("Authorization");
        BaseDTO baseDTO = getClaim(token);
        if (baseDTO == null) {
            return illegalResponse(exchange, "{\"code\": \"401\",\"msg\": \"unauthorized.\"}");
        }

        ServerHttpRequest.Builder builder = request.mutate();
        boolean needsIdentity = UrlConstant.addRequestParameterUrls.stream()
                .anyMatch(path -> antPathMatcher.match(path, url));
        if (needsIdentity) {
            HttpMethod method = request.getMethod();
            if (method == HttpMethod.GET || method == HttpMethod.DELETE) {
                // No body expected: identity travels in the query string.
                return addParameterForGetMethod(exchange, chain, uri, baseDTO, builder);
            }
            if (method == HttpMethod.POST || method == HttpMethod.PUT) {
                // Dispatch on content type for PUT as well as POST, so that a non-JSON PUT body
                // is never parsed and re-serialised as JSON text.
                MediaType contentType = request.getHeaders().getContentType();
                // isCompatibleWith ignores parameters such as charset and returns false for null.
                if (MediaType.APPLICATION_JSON.isCompatibleWith(contentType)) {
                    return addParameterForPostMethod(exchange, chain, baseDTO);
                }
                if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
                    return addParameterForFormData(exchange, chain, baseDTO, builder);
                }
            }
        }
        // Nothing to inject: continue with the request as is.
        return chain.filter(exchange.mutate().request(builder.build()).build());
    }
}
复制代码
Get请求 添加参数
/**
 * Adds the caller's identity ({@code userId}, {@code userName}) to the query string and
 * forwards the request. Used for requests that carry their parameters in the URL.
 *
 * <p>Any {@code userId}/{@code userName} parameter sent by the client is removed first, so the
 * downstream service only sees the values derived from the token. The identity values are
 * percent-encoded individually, so characters such as {@code &} or {@code =} inside a value
 * cannot introduce extra parameters.
 *
 * @param exchange the current exchange
 * @param chain the remaining filter chain
 * @param uri the URI of the incoming request
 * @param baseDTO identity to inject
 * @param builder builder for the incoming request, used for the fallback path
 * @return completion signal of the downstream chain
 */
private Mono<Void> addParameterForGetMethod(ServerWebExchange exchange, GatewayFilterChain chain, URI uri, BaseDTO baseDTO, ServerHttpRequest.Builder builder) {
    try {
        // fromUri keeps the query in its raw (already encoded) form.
        UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUri(uri);

        // Drop client-supplied identity parameters, including percent-encoded spellings of the names.
        // A second builder is used for iteration so the map being modified is never the one iterated.
        for (String rawName : UriComponentsBuilder.fromUri(uri).build(true).getQueryParams().keySet()) {
            String name = java.net.URLDecoder.decode(rawName, "UTF-8");
            if ("userId".equals(name) || "userName".equals(name)) {
                uriBuilder.replaceQueryParam(rawName);
            }
        }

        String userId = URLEncoder.encode(String.valueOf(baseDTO.getUserId()), "UTF-8");
        String userName = URLEncoder.encode(String.valueOf(baseDTO.getUserName()), "UTF-8");

        // build(true): every component is already encoded, so nothing is encoded a second time.
        URI newUri = uriBuilder
                .replaceQueryParam("userId", userId)
                .replaceQueryParam("userName", userName)
                .build(true)
                .toUri();
        ServerHttpRequest request = exchange.getRequest().mutate().uri(newUri).build();
        return chain.filter(exchange.mutate().request(request).build());
    } catch (Exception e) {
        LOGGER.error("Invalid URI query: " + uri.getRawQuery(), e);
        // NOTE(review): best-effort fallback kept from the original — the request is forwarded
        // without the injected identity. Consider rejecting the request here instead.
        return chain.filter(exchange.mutate().request(builder.build()).build());
    }
}
复制代码
Post请求 添加参数
请求内容为 application json
/** Shared JSON mapper, reused across requests instead of being created per call. */
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

/**
 * Adds the caller's identity to a JSON request body and forwards the request.
 *
 * <p>The body is read as text; when it is a JSON object, {@code userId} and {@code userName}
 * are put into it (replacing any client-supplied values). Array bodies and bodies that cannot
 * be parsed are forwarded unchanged.
 *
 * @param exchange the current exchange
 * @param chain the remaining filter chain
 * @param baseDTO identity to inject
 * @return completion signal of the downstream chain
 */
private Mono<Void> addParameterForPostMethod(ServerWebExchange exchange, GatewayFilterChain chain, BaseDTO baseDTO) {
    ServerRequest serverRequest = new DefaultServerRequest(exchange);
    Mono<String> modifiedBody = serverRequest.bodyToMono(String.class)
            .map(body -> injectIdentity(body, baseDTO));
    BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, String.class);

    // The rewritten body is normally longer than the original, so the original Content-Length
    // must not be carried over; work on a copy of the headers without it.
    HttpHeaders headers = new HttpHeaders();
    headers.putAll(exchange.getRequest().getHeaders());
    headers.remove(HttpHeaders.CONTENT_LENGTH);
    CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, headers);

    return bodyInserter.insert(outputMessage, new BodyInserterContext()).then(Mono.defer(() -> {
        ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public HttpHeaders getHeaders() {
                // Length recorded on the copied headers while the new body was written, if any.
                long contentLength = headers.getContentLength();
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(super.getHeaders());
                if (contentLength > 0) {
                    httpHeaders.setContentLength(contentLength);
                } else {
                    // Unknown length: never forward the stale original value.
                    httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
                    httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                }
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                return outputMessage.getBody();
            }
        };
        return chain.filter(exchange.mutate().request(decorator).build());
    }));
}

/** Returns {@code body} with the identity fields added when it is a JSON object; otherwise returns it unchanged. */
private String injectIdentity(String body, BaseDTO baseDTO) {
    if (body.startsWith("[")) {
        // A JSON array has no place for the identity fields; pass it through.
        return body;
    }
    try {
        @SuppressWarnings("unchecked") // Map.class carries no type arguments; a JSON object has String keys
        Map<String, Object> map = OBJECT_MAPPER.readValue(body, Map.class);
        map.put("userId", baseDTO.getUserId());
        map.put("userName", baseDTO.getUserName());
        String json = OBJECT_MAPPER.writeValueAsString(map);
        // Debug level: request bodies may contain sensitive data.
        LOGGER.debug("addParameterForPostMethod -> json = {}", json);
        return json;
    } catch (Exception e) {
        // Deliberate best-effort: an unparseable body is forwarded as received.
        LOGGER.warn("addParameterForPostMethod -> request body is not a JSON object, forwarding unchanged", e);
        return body;
    }
}
复制代码
请求内容为 form data
/**
 * Adds the caller's identity to a form-data request and forwards it.
 *
 * <p>The multipart body is left untouched; the identity is passed in the {@code userId} and
 * {@code userName} request headers instead. The headers are set rather than appended, so
 * values sent by the client under the same names are replaced.
 *
 * @param exchange the current exchange
 * @param chain the remaining filter chain
 * @param baseDTO identity to inject
 * @param builder builder for the incoming request
 * @return completion signal of the downstream chain
 */
private Mono<Void> addParameterForFormData(ServerWebExchange exchange, GatewayFilterChain chain, BaseDTO baseDTO, ServerHttpRequest.Builder builder) {
    String userId = String.valueOf(baseDTO.getUserId());
    String userName = encodeHeaderValue(String.valueOf(baseDTO.getUserName()));
    builder.headers(httpHeaders -> {
        httpHeaders.set("userId", userId);
        httpHeaders.set("userName", userName);
    });
    ServerHttpRequest serverHttpRequest = builder.build();
    return chain.filter(exchange.mutate().request(serverHttpRequest).build());
}

/** URL-encodes a header value as UTF-8 so non-ASCII text survives transport; returns the raw value if encoding is unavailable. */
private static String encodeHeaderValue(String value) {
    try {
        return URLEncoder.encode(value, "UTF-8");
    } catch (UnsupportedEncodingException e) {
        return value;
    }
}
复制代码
返回数据处理
/**
 * Completes the exchange immediately by writing {@code data} as the response body.
 *
 * <p>The HTTP status is set to 200 and a JSON (UTF-8) content type header is added; the text
 * is written as UTF-8 bytes.
 *
 * @param exchange the current exchange
 * @param data the response body text
 * @return completion signal of the response write
 */
private Mono<Void> illegalResponse(ServerWebExchange exchange, String data) {
    ServerHttpResponse reply = exchange.getResponse();
    reply.setStatusCode(HttpStatus.OK);

    HttpHeaders replyHeaders = reply.getHeaders();
    replyHeaders.add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_UTF8_VALUE);

    DataBuffer payload = reply.bufferFactory().wrap(data.getBytes(StandardCharsets.UTF_8));
    return reply.writeWith(Flux.just(payload));
}
复制代码
最终效果
按上面的实现,userId 和 userName 两个属性已写入转发的请求中:GET/DELETE 请求追加到查询参数,POST/PUT 的 JSON 请求写入请求体,multipart/form-data 请求则写入请求头。
在具体业务处理的服务模块,controller层的传参,只要继承包含userId和userName两个属性的BaseDTO类,就可以拿到该信息,用于实际的业务流程(form data 请求的身份信息需要从请求头中读取)。