路由过滤器允许以某种方式修改传入的HTTP请求或传出的HTTP响应,路径过滤器的范围限定为特定路径,Spring Cloud Gateway包含许多内置的GatewayFilter工厂。
Spring Cloud Gateway限流就是通过内置的RequestRateLimiterGateWayFilterFactory工厂来实现的。
当然,官方的肯定不能满足我们部分业务需求,因此可以自定义限流过滤器。
## yml如下配置,就可以为该路由添加此拦截器:
spring:
cloud:
gateway:
routes:
- id: test_route
uri: localhost
predicates:
- Path=/host/address
filters:
- name: RequestRateLimiter
args:
## 允许用户每秒执行多少请求,而不会丢弃任何请求。这是令牌桶填充的速率。
redis-rate-limiter.replenishRate: 1
## 是一秒钟内允许用户执行的最大请求数。这是令牌桶可以容纳的令牌数。将此值设置为零将阻止所有请求。
redis-rate-limiter.burstCapacity: 3
## KeyResolver是一个简单的获取用户请求参数 我这里以主机地址为key来作限流
key-resolver: "#{@hostAddrKeyResolver}"
## RequestRateLimiterGateWayFilterFactory代码:
//AbstractGatewayFilterFactory实现GatewayFilterFactory接口,自定义的过滤工厂可以继承
//AbstractGatewayFilterFactory并编写apply方法
public class RequestRateLimiterGatewayFilterFactory extends AbstractGatewayFilterFactory<RequestRateLimiterGatewayFilterFactory.Config> {
public static final String KEY_RESOLVER_KEY = "keyResolver";
private final RateLimiter defaultRateLimiter;
private final KeyResolver defaultKeyResolver;
public RequestRateLimiterGatewayFilterFactory(RateLimiter defaultRateLimiter,
KeyResolver defaultKeyResolver) {
super(Config.class);
this.defaultRateLimiter = defaultRateLimiter;
this.defaultKeyResolver = defaultKeyResolver;
}
public KeyResolver getDefaultKeyResolver() {
return defaultKeyResolver;
}
public RateLimiter getDefaultRateLimiter() {
return defaultRateLimiter;
}
@SuppressWarnings("unchecked")
@Override
public GatewayFilter apply(Config config) {
//yml中我们配置的hostAddrKeyResolver
KeyResolver resolver = (config.keyResolver == null) ? defaultKeyResolver : config.keyResolver;
//这个就是限流的具体实现,默认使用RedisRateLimiter
RateLimiter<Object> limiter = (config.rateLimiter == null) ? defaultRateLimiter : config.rateLimiter;
return (exchange, chain) -> {
Route route = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
return resolver.resolve(exchange).flatMap(key ->
//这里的isAllowed就是具体实现,输入参数为路由id和限流key(这里为主机地址hostAddress)
// TODO: if key is empty?
limiter.isAllowed(route.getId(), key).flatMap(response -> {
for (Map.Entry<String, String> header : response.getHeaders().entrySet()) {
exchange.getResponse().getHeaders().add(header.getKey(), header.getValue());
}
//如果为真,通过拦截
if (response.isAllowed()) {
return chain.filter(exchange);
}
//否则设置http码为429,too many request
exchange.getResponse().setStatusCode(config.getStatusCode());
return exchange.getResponse().setComplete();
}));
};
}
}
分析:
1.加载KeyResolver,从配置文件中加载,此处我配置了hostAddrKeyResolver,即根据host地址来进行限流。如果为空,使用默认的PrincipalNameKeyResolver
2.加载RateLimiter,默认使用RedisRateLimiter。
3.执行RedisRateLimiter的isAllowed方法,得到response,如果isAllowed为true则通过拦截,否则返回429(isAllowed方法具体实现下文描述)。
## HostAddrKeyResolver:
@Slf4j
public class HostAddrKeyResolver implements KeyResolver {
@Override
public Mono<String> resolve(ServerWebExchange exchange) {
log.info("HostAddrKeyResolver 限流");
return Mono.just(exchange.getRequest().getRemoteAddress().getHostName());
}
}
在启动类中注入bean
@Bean
public HostAddrKeyResolver hostAddrKeyResolver() {
return new HostAddrKeyResolver();
}
## RedisRateLimiter:
@Override
@SuppressWarnings("unchecked")
public Mono<Response> isAllowed(String routeId, String id) {
//判断是否初始化
if (!this.initialized.get()) {
throw new IllegalStateException("RedisRateLimiter is not initialized");
}
//获取配置
Config routeConfig = getConfig().getOrDefault(routeId, defaultConfig);
if (routeConfig == null) {
throw new IllegalArgumentException("No Configuration found for route " + routeId);
}
//令牌桶填充速率
int replenishRate = routeConfig.getReplenishRate();
//令牌桶可容纳令牌数
int burstCapacity = routeConfig.getBurstCapacity();
try {
//获取redis的key,执行lua脚本时传入
List<String> keys = getKeys(id);
//获取参数,执行lua脚本时传入
List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "",
Instant.now().getEpochSecond() + "", "1");
Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);
// .log("redisratelimiter", Level.FINER);
return flux.onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L)))
.reduce(new ArrayList<Long>(), (longs, l) -> {
longs.addAll(l);
return longs;
}) .map(results -> {
boolean allowed = results.get(0) == 1L;
Long tokensLeft = results.get(1);
Response response = new Response(allowed, getHeaders(routeConfig, tokensLeft));
if (log.isDebugEnabled()) {
log.debug("response: " + response);
}
return response;
});
}
catch (Exception e) {
/*
* We don't want a hard dependency on Redis to allow traffic. Make sure to set
* an alert so you know if this is happening too much. Stripe's observed
* failure rate is 0.01%.
*/
log.error("Error determining if user allowed from redis", e);
}
return Mono.just(new Response(true, getHeaders(routeConfig, -1L)));
}
@NotNull
public HashMap<String, String> getHeaders(Config config, Long tokensLeft) {
HashMap<String, String> headers = new HashMap<>();
headers.put(this.remainingHeader, tokensLeft.toString());
headers.put(this.replenishRateHeader, String.valueOf(config.getReplenishRate()));
headers.put(this.burstCapacityHeader, String.valueOf(config.getBurstCapacity()));
return headers;
}
static List<String> getKeys(String id) {
// use `{}` around keys to use Redis Key hash tags
// this allows for using redis cluster
// Make a unique key per user.
String prefix = "request_rate_limiter.{" + id;
//令牌桶剩余令牌数
String tokenKey = prefix + "}.tokens";
//令牌桶最后填充令牌时间
String timestampKey = prefix + "}.timestamp";
return Arrays.asList(tokenKey, timestampKey);
}
分析:
1.判断是否初始化,加载配置,获取令牌填充速率和令牌桶大小
2.根据路由id组合成两个redis中的key值,传入lua脚本
request_rate_limiter.{id}.tokens 令牌桶剩余令牌数
request_rate_limiter.{id}.timestamp 令牌桶最后填充令牌时间
3.把令牌填充速率,令牌桶大小,当前时间(单位:秒),消耗令牌数(默认为1)组合传入lua脚本
4.执行lua脚本
5.flux.onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L))) 这个是对执行lua脚本过程中发生异常的处理,它会忽略异常,返回令牌。这样就能跟redis解耦,不对它强依赖。
该实现核心主要体现在lua脚本上,它使用的是令牌桶算法
详见spring-cloud-gateway-core下的request_rate_limiter.lua
## 获取剩余令牌数的redis key
local tokens_key = KEYS[1]
## 获取最后一次填充令牌的时间
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
## 令牌填充速率
local rate = tonumber(ARGV[1])
## 令牌桶大小
local capacity = tonumber(ARGV[2])
## 当前秒数
local now = tonumber(ARGV[3])
## 消耗令牌数,默认1
local requested = tonumber(ARGV[4])
## 计算令牌桶需要填充的时间
local fill_time = capacity/rate
## 计算key的存活时间
local ttl = math.floor(fill_time*2)
--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)
## 获取剩余的令牌数
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)
## 获取令牌最后填充时间
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
local delta = math.max(0, now-last_refreshed)
## 计算得到剩余的令牌数
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
## 大于请求消耗令牌 allowed 设为true
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
return { allowed_num, new_tokens }