【SpringCloud】Gateway自定义RoutePredicateFactory路由断言工厂、GatewayFilterFactory路由过滤器工厂和全局过滤器
文章目录
1. 自定义路由断言工厂
1.1 概述
Gateway内置了十几种路由断言工厂,比如 After、Before、Between 等,路由断言工厂的架构图如下:
现在有一个需求:
用户种类有两种,一种是普通用户,另一种是vip用户。只有vip用户才能够访问接口。
Gateway提供的现有路由断言工厂都无法实现这个需求,那么就需要我们自定义一个路由断言工厂。
1.2 编码实现
自定义路由断言工厂的实现可以参考Gateway内置的路由断言工厂,如 AfterRoutePredicateFactory:
/**
 * Route predicate factory whose predicate passes only when the current time is after the
 * configured date-time.
 */
public class AfterRoutePredicateFactory extends AbstractRoutePredicateFactory<AfterRoutePredicateFactory.Config> {

    /**
     * DateTime key.
     */
    public static final String DATETIME_KEY = "datetime";

    /** Registers {@link Config} as the configuration class of this factory. */
    public AfterRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * @return the single shortcut field name, {@link #DATETIME_KEY}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList(DATETIME_KEY);
    }

    /**
     * Creates the predicate for one route definition.
     *
     * @param config holder of the threshold date-time
     * @return a predicate that is true when "now" is after {@code config.getDatetime()}
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        return new GatewayPredicate() {

            @Override
            public Object getConfig() {
                return config;
            }

            @Override
            public boolean test(ServerWebExchange exchange) {
                // The exchange itself is not inspected; only the wall-clock time matters.
                return ZonedDateTime.now().isAfter(config.getDatetime());
            }

            @Override
            public String toString() {
                return String.format("After: %s", config.getDatetime());
            }

        };
    }

    /** Configuration bean holding the threshold date-time. */
    public static class Config {

        @NotNull
        private ZonedDateTime datetime;

        public ZonedDateTime getDatetime() {
            return datetime;
        }

        public void setDatetime(ZonedDateTime datetime) {
            this.datetime = datetime;
        }

    }

}
参考它的格式编写我们自己的路由断言工厂,规则如下:
- 要么继承 AbstractRoutePredicateFactory 抽象类
- 要么实现 RoutePredicateFactory 接口
- 自定义路由断言工厂的名字开头随意,但是必须以 RoutePredicateFactory 结尾(即 *RoutePredicateFactory)。
自定义断言工厂代码如下:
/**
 * Custom route predicate factory: a request matches only when its {@code userType} query
 * parameter equals the user type configured on the route.
 */
@Component
public class CustomRoutePredicateFactory extends AbstractRoutePredicateFactory<CustomRoutePredicateFactory.Config> {

    /**
     * userType key.
     */
    public static final String USERTYPE_KEY = "userType";

    /** Registers {@link Config} as the configuration class of this factory. */
    public CustomRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * @return the single shortcut field name, {@link #USERTYPE_KEY}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList(USERTYPE_KEY);
    }

    /**
     * Creates the predicate for one route definition.
     *
     * @param config holder of the user type that is allowed through
     * @return a predicate that is true only when the request's first {@code userType} query
     *         parameter is present and equals {@code config.getUserType()}
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        return exchange -> {
            // Read the first "userType" query parameter; a missing parameter never matches.
            final String requestedType = exchange.getRequest().getQueryParams().getFirst("userType");
            return requestedType != null && requestedType.equals(config.getUserType());
        };
    }

    /** Configuration bean holding the user type required by the route. */
    public static class Config {

        @NotNull
        private String userType;

        public String getUserType() {
            return userType;
        }

        public void setUserType(String userType) {
            this.userType = userType;
        }

    }

}
1.3 测试
在配置文件中添加如下配置:
gateway:
  routes:
    - id: xxx
      uri: xxxx
      predicates:
        - Custom=vip
访问 http://localhost:9527/pay/gateway/get/1?userType=common 地址,无法访问:
访问 http://localhost:9527/pay/gateway/get/1?userType=vip 地址,可以访问:
2. 自定义路由过滤器工厂
2.1 概述
Gateway内置了三十多种路由过滤器,架构如下:
现在有一个需求:
经过网关的请求必须携带一个参数,参数名必须为 param。使用自定义路由过滤器实现。
2.2 编码实现
具体实现可以参考Gateway提供的路由过滤器工厂 AddRequestHeadersIfNotPresentGatewayFilterFactory:
/**
 * Gateway filter factory that adds the configured request headers, but only for header
 * names whose current values on the incoming request are all missing or blank.
 */
public class AddRequestHeadersIfNotPresentGatewayFilterFactory
extends AbstractGatewayFilterFactory<AddRequestHeadersIfNotPresentGatewayFilterFactory.KeyValueConfig> {
/**
 * Builds the filter for one route configuration.
 *
 * @param config the key/value pairs naming the headers to add and the values to use
 * @return a filter that adds each missing-or-blank header and then continues the chain
 */
@Override
public GatewayFilter apply(KeyValueConfig config) {
return new GatewayFilter() {
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
// Created lazily: stays null when no header needs to be added.
ServerHttpRequest.Builder requestBuilder = null;
// Group the configured values by header name so a repeated key yields several values.
Map<String, List<String>> aggregatedHeaders = new HashMap<>();
for (KeyValue keyValue : config.getKeyValues()) {
String key = keyValue.getKey();
List<String> candidateValue = aggregatedHeaders.get(key);
// Both branches append the value; the list is only created on the first occurrence.
if (candidateValue == null) {
candidateValue = new ArrayList<>();
candidateValue.add(keyValue.getValue());
}
else {
candidateValue.add(keyValue.getValue());
}
aggregatedHeaders.put(key, candidateValue);
}
for (Map.Entry<String, List<String>> kv : aggregatedHeaders.entrySet()) {
String headerName = kv.getKey();
// allMatch is true for an empty stream, so a header with no values counts as missing.
// NOTE(review): presumably getOrEmpty returns an empty list for an absent header - confirm.
boolean headerIsMissingOrBlank = exchange.getRequest().getHeaders().getOrEmpty(headerName).stream()
.allMatch(h -> !StringUtils.hasText(h));
if (headerIsMissingOrBlank) {
if (requestBuilder == null) {
requestBuilder = exchange.getRequest().mutate();
}
// 'exchange' is reassigned further down, so the lambda needs an effectively final copy.
ServerWebExchange finalExchange = exchange;
requestBuilder.headers(httpHeaders -> {
// NOTE(review): expand() presumably resolves template variables in the value - confirm.
List<String> replacedValues = kv.getValue().stream()
.map(value -> ServerWebExchangeUtils.expand(finalExchange, value))
.collect(Collectors.toList());
httpHeaders.addAll(headerName, replacedValues);
});
}
}
// Swap in the mutated request only when at least one header was added.
if (requestBuilder != null) {
exchange = exchange.mutate().request(requestBuilder.build()).build();
}
return chain.filter(exchange);
}
/**
 * @return a description built via filterToStringCreator with every configured key/value appended
 */
@Override
public String toString() {
ToStringCreator toStringCreator = filterToStringCreator(
AddRequestHeadersIfNotPresentGatewayFilterFactory.this);
for (KeyValue keyValue : config.getKeyValues()) {
toStringCreator.append(keyValue.getKey(), keyValue.getValue());
}
return toStringCreator.toString();
}
};
}
/**
 * @return {@code ShortcutType.GATHER_LIST}
 */
public ShortcutType shortcutType() {
return ShortcutType.GATHER_LIST;
}
/**
 * @return the single shortcut field name, {@code "keyValues"}
 */
@Override
public List<String> shortcutFieldOrder() {
return Collections.singletonList("keyValues");
}
/**
 * @return a fresh, empty {@link KeyValueConfig}
 */
@Override
public KeyValueConfig newConfig() {
return new KeyValueConfig();
}
/**
 * @return the configuration class, {@link KeyValueConfig}
 */
@Override
public Class<KeyValueConfig> getConfigClass() {
return KeyValueConfig.class;
}
/**
* @deprecated in favour of
* {@link org.springframework.cloud.gateway.support.config.KeyValueConfig}
*/
@Deprecated
public static class KeyValueConfig {
// Configured header name/value pairs; the array is stored and returned without copying.
private KeyValue[] keyValues;
public KeyValue[] getKeyValues() {
return keyValues;
}
public void setKeyValues(KeyValue[] keyValues) {
this.keyValues = keyValues;
}
}
/**
* @deprecated in favour of
* {@link org.springframework.cloud.gateway.support.config.KeyValue}
*/
@Deprecated
public static class KeyValue {
private final String key;
private final String value;
public KeyValue(String key, String value) {
this.key = key;
this.value = value;
}
public String getKey() {
return key;
}
public String getValue() {
return value;
}
// Note: the key is rendered under the label "name", not "key".
@Override
public String toString() {
return new ToStringCreator(this).append("name", key).append("value", value).toString();
}
}
}
参考它的格式,自定义路由过滤器工厂的规则如下:
- 继承 AbstractGatewayFilterFactory 抽象类
- 类名开头随意,但是必须以 GatewayFilterFactory 结尾(即 *GatewayFilterFactory)。
代码如下:
/**
 * Custom gateway filter factory: lets a request through only when it carries a query
 * parameter named {@code param}; otherwise the response is completed with 400 Bad Request.
 */
@Component
public class CustomGatewayFilterFactory extends AbstractGatewayFilterFactory<CustomGatewayFilterFactory.Config> {

    /** Registers {@link Config} as the configuration class of this factory. */
    public CustomGatewayFilterFactory() {
        super(CustomGatewayFilterFactory.Config.class);
    }

    /**
     * Builds the filter for one route configuration.
     *
     * @param config the route's filter configuration; its status value is only printed
     * @return a filter that continues the chain when the {@code param} query parameter is
     *         present and otherwise ends the exchange with HTTP 400
     */
    @Override
    public GatewayFilter apply(CustomGatewayFilterFactory.Config config) {
        return new GatewayFilter() {
            @Override
            public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
                ServerHttpRequest request = exchange.getRequest();
                System.out.println("进入了自定义网关过滤器CustomGatewayFilterFactory,param:" + config.getStatus());
                if (request.getQueryParams().containsKey("param")) {
                    return chain.filter(exchange);
                } else {
                    // Reject without calling downstream: set 400 and finish the response.
                    exchange.getResponse().setStatusCode(HttpStatus.BAD_REQUEST);
                    return exchange.getResponse().setComplete();
                }
            }
        };
    }

    /**
     * Names the {@link Config} property that receives the shortcut value of
     * {@code - Custom=<value>}.
     *
     * <p>The name must be a property of {@link Config}. It previously returned
     * {@code "param"}, which {@link Config} does not declare, so {@code status} was never
     * populated from the shortcut form.
     *
     * @return the single shortcut field name, {@code "status"}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList("status");
    }

    /** Configuration bean for this filter. */
    public static class Config {

        // Status value / flag taken from the route configuration.
        @Getter
        @Setter
        private String status;

    }

}
2.3 测试
在配置文件进行配置:
gateway:
  routes:
    - id: xxx
      uri: xxx
      predicates:
        - Custom=vip
      filters:
        - Custom=param
访问 http://localhost:9527/pay/gateway/filter 地址,失败。
访问 http://localhost:9527/pay/gateway/filter?param=java 地址,成功。
3. 自定义全局过滤器
需求,定义一个全局过滤器统计所有接口的耗时。
代码如下:
/**
 * Global filter that records when each request enters the gateway and, once the rest of
 * the filter chain completes, logs the request's host, port, path, raw query and elapsed
 * time.
 */
@Component
@Slf4j
public class CustomGlobalFilter implements GlobalFilter, Ordered {

    /** Exchange-attribute key under which the start time (epoch milliseconds) is stored. */
    private static final String BEGIN_VISIT_TIME = "begin_visit_time";

    /**
     * Lower values mean higher priority.
     *
     * @return the order of this filter, {@code 0}
     */
    @Override
    public int getOrder() {
        return 0;
    }

    /**
     * Stamps the exchange with the current time, delegates to the chain, and logs the
     * request statistics after the chain completes.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining filter chain
     * @return a {@code Mono} that completes after the chain and the logging step
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // Remember when the request entered the gateway.
        exchange.getAttributes().put(BEGIN_VISIT_TIME, System.currentTimeMillis());
        Mono<Void> downstream = chain.filter(exchange);
        return downstream.then(Mono.fromRunnable(() -> reportVisit(exchange)));
    }

    /** Logs the statistics of one finished request; does nothing if no start time was stored. */
    private void reportVisit(ServerWebExchange exchange) {
        Long startedAt = exchange.getAttribute(BEGIN_VISIT_TIME);
        if (startedAt == null) {
            return;
        }
        log.info("访问接口主机: " + exchange.getRequest().getURI().getHost());
        log.info("访问接口端口: " + exchange.getRequest().getURI().getPort());
        log.info("访问接口URL: " + exchange.getRequest().getURI().getPath());
        log.info("访问接口URL参数: " + exchange.getRequest().getURI().getRawQuery());
        log.info("访问接口时长: " + (System.currentTimeMillis() - startedAt) + "ms");
        log.info("我是美丽分割线: ###################################################");
        System.out.println();
    }

}