springcloud-gateway: token-based rate limiting
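1. The rate limiting built into springcloud-gateway keys each request on a value resolved from the ServerWebExchange. That key can be the client IP, the Principal, or any other custom parameter.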
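2. Two common rate-limiting algorithms are the token bucket and the leaky bucket. The Redis-backed RedisRateLimiter that ships with springcloud-gateway uses the token bucket. Their implementation and theory are not repeated here; look them up if needed.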
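To make the token bucket concrete, here is a minimal in-memory sketch. It follows the same idea as the gateway's Redis script: each key owns a bucket of at most burstCapacity tokens, replenishRate tokens are added per elapsed time unit, and each request spends one token. The class TokenBucketSketch is hypothetical and single-threaded; the real limiter keeps this state in Redis.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical single-threaded token bucket, keyed by rate-limit key. */
final class TokenBucketSketch {
    private final long replenishRate;   // tokens added per elapsed time unit
    private final long burstCapacity;   // maximum tokens a bucket can hold
    // key -> {tokens left, time of last refill}
    private final Map<String, long[]> buckets = new HashMap<>();

    TokenBucketSketch(long replenishRate, long burstCapacity) {
        this.replenishRate = replenishRate;
        this.burstCapacity = burstCapacity;
    }

    /** Returns true if one request for this key at time 'now' is allowed. */
    boolean tryAcquire(String key, long now) {
        // A key seen for the first time starts with a full bucket
        long[] state = buckets.computeIfAbsent(key, k -> new long[] {burstCapacity, now});
        long elapsed = Math.max(0, now - state[1]);
        // Refill for the elapsed time, capped at burstCapacity
        long tokens = Math.min(burstCapacity, state[0] + elapsed * replenishRate);
        boolean allowed = tokens >= 1;
        state[0] = allowed ? tokens - 1 : tokens;
        state[1] = now;
        return allowed;
    }

    public static void main(String[] args) {
        TokenBucketSketch bucket = new TokenBucketSketch(1, 3);
        for (int i = 0; i < 5; i++) {
            System.out.println("t=0 request " + i + ": " + bucket.tryAcquire("level-1", 0));
        }
        System.out.println("t=2 request: " + bucket.tryAcquire("level-1", 2));
    }
}
```

With replenishRate 1 and burstCapacity 3, the first three requests at t=0 pass and the next two are rejected; two time units later two tokens have been refilled.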
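3. What we want is per-user rate limiting. Once the concepts are clear, keying on the IP obviously will not work. Keying on the Principal requires authentication before the user is available, so that option is ruled out as well. That leaves a custom parameter. We chose the token because, in a front-end/back-end separated architecture, every request the front end sends to the back end must carry a token to call the API, so the token is always present.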
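The first step of keying on the token is pulling it out of the Authorization header. The sketch below assumes the common `Authorization: Bearer <token>` form and a hypothetical shared fallback key for requests without a token; neither the class name nor the constants come from the project.

```java
import java.util.Optional;

/** Hypothetical helper that turns an Authorization header into a rate-limit key. */
final class TokenKeySketch {
    // Assumption: clients send "Authorization: Bearer <token>"
    static final String SCHEME = "Bearer";
    // Assumption: requests without a token share one fallback key
    static final String ANONYMOUS_KEY = "anonymous";

    static Optional<String> extractToken(String header) {
        if (header == null || header.trim().isEmpty()) {
            return Optional.empty();
        }
        // Split "<scheme> <credentials>" on the first run of whitespace
        String[] parts = header.trim().split("\\s+", 2);
        if (parts[0].equalsIgnoreCase(SCHEME)) {
            return parts.length == 2 ? Optional.of(parts[1]) : Optional.empty();
        }
        // No recognised scheme: treat the whole value as the token
        return Optional.of(header.trim());
    }

    /** The token if present, otherwise the shared fallback key. */
    static String rateLimitKey(String header) {
        return extractToken(header).orElse(ANONYMOUS_KEY);
    }

    public static void main(String[] args) {
        System.out.println(rateLimitKey("Bearer 3f2a-token"));
        System.out.println(rateLimitKey(null));
    }
}
```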
Getting started
1. Configure the filter in YAML
routes:
  - id: ms-admin # the gateway routes to the admin service
    uri: lb://ms-admin
    predicates:
      - Path=/admin/**
    filters:
      - StripPrefix=1
      - name: Hystrix
        args:
          name: fallbackcmd
          fallbackUri: forward:/fallback
      - name: RequestRateLimiter
        args:
          rate-limiter: "#{@customRedisRateLimiter}"
          key-resolver: "#{@principalNameKeyResolver}"
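Configure the RequestRateLimiter filter on the gateway route. Both arguments use the SpEL bean reference syntax #{@beanName}:
- key-resolver: the name of the bean that resolves the key used for rate limiting
- rate-limiter: the name of the bean that implements the rate limiting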
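The division of labour between the two beans can be modelled without the framework: the key-resolver turns a request into a key, and the rate-limiter decides whether that key may proceed. In this hypothetical sketch a header map stands in for the exchange and plain functional interfaces stand in for KeyResolver and RateLimiter; it is not the Spring Cloud Gateway API. The 429 status matches the filter's default response for rejected requests.

```java
import java.util.Collections;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

/** Hypothetical, framework-free model of the RequestRateLimiter filter flow. */
final class RateLimitFlowSketch {
    static final int OK = 200;
    static final int TOO_MANY_REQUESTS = 429;

    static int handle(Map<String, String> headers,
                      String routeId,
                      Function<Map<String, String>, String> keyResolver,
                      BiFunction<String, String, Boolean> rateLimiter) {
        String key = keyResolver.apply(headers);            // role of the key-resolver bean
        boolean allowed = rateLimiter.apply(routeId, key);  // role of the rate-limiter bean
        // A rejected request is answered with 429 Too Many Requests
        return allowed ? OK : TOO_MANY_REQUESTS;
    }

    public static void main(String[] args) {
        Function<Map<String, String>, String> resolver =
                h -> h.getOrDefault("Authorization", "anonymous");
        // Toy limiter: reject anonymous callers only
        BiFunction<String, String, Boolean> limiter = (route, key) -> !key.equals("anonymous");
        System.out.println(handle(Collections.singletonMap("Authorization", "Bearer t1"),
                "ms-admin", resolver, limiter));
        System.out.println(handle(Collections.<String, String>emptyMap(),
                "ms-admin", resolver, limiter));
    }
}
```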
2. Write the corresponding beans
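import cn.hutool.core.collection.CollUtil;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.security.oauth2.provider.token.store.redis.RedisTokenStoreSerializationStrategy;
import org.springframework.validation.Validator;
import reactor.core.publisher.Mono;
import java.util.List;
// CommonConstants, SecurityConstants and StoreUser are classes from the author's project
@Slf4j
@Configuration
@AllArgsConstructor
public class RateLimiterConfiguration {
    private final RedisTemplate redisTemplate;
    private final RedisTokenStoreSerializationStrategy redisTokenStoreSerializationStrategy;

    /** Resolves the rate-limit key: the limit level of the user who owns the token. */
    @Bean
    public KeyResolver principalNameKeyResolver() {
        return exchange -> {
            List<String> authorization = exchange.getRequest().getHeaders().get(CommonConstants.AUTHORIZATION);
            if (CollUtil.isNotEmpty(authorization)) {
                String token = authorization.get(0);
                // Strip the token type prefix (CommonConstants.PREFIX) if present
                int prefixIndex = token.indexOf(CommonConstants.PREFIX);
                if (prefixIndex >= 0) {
                    token = token.substring(prefixIndex + CommonConstants.PREFIX.length());
                }
                token = token.trim();
                String key = SecurityConstants.MS_OAUTH_PREFIX + CommonConstants.AUTH_USER + token;
                // Blocking read of the user that the authentication service cached in Redis
                RedisConnection connection = redisTemplate.getConnectionFactory().getConnection();
                byte[] principal;
                try {
                    principal = connection.get(redisTokenStoreSerializationStrategy.serialize(key));
                } finally {
                    // Release the connection back to the pool
                    connection.close();
                }
                if (principal != null) {
                    StoreUser user = redisTokenStoreSerializationStrategy.deserialize(principal, StoreUser.class);
                    return Mono.just(user.getLimitLevel() == 0 ? CommonConstants.DEFAULT_LEVEL : String.valueOf(user.getLimitLevel()));
                }
            }
            // No token or no cached user: fall back to the default level
            return Mono.just(CommonConstants.DEFAULT_LEVEL);
        };
    }

    @Bean
    @Primary
    public RateLimiter customRedisRateLimiter(
            ReactiveRedisTemplate<String, String> redisTemplate,
            @Qualifier(CustomRedisRateLimiter.REDIS_SCRIPT_NAME) RedisScript<List<Long>> script,
            Validator validator) {
        return new CustomRedisRateLimiter(redisTemplate, script, validator);
    }
}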
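The two beans in this class are the two beans referenced in the YAML.
The principalNameKeyResolver bean reads the token from the request header, builds the Redis key from it, and looks up the user information that the authentication service stored earlier. It returns that user's rate-limit level as the key, and falls back to a default level when no user information is found. Because the key is the level rather than the user, all users at the same level share one token bucket.
The customRedisRateLimiter bean extends AbstractRateLimiter and is marked @Primary, which effectively replaces the default Redis rate limiter.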
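The resolver's decision logic can be tried without Spring or Redis. In this sketch a Map stands in for the Redis token store and StoreUserSketch for the project's StoreUser class; both, and the DEFAULT_LEVEL value (the real CommonConstants.DEFAULT_LEVEL is not shown in the article), are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical model of principalNameKeyResolver's lookup and fallback. */
final class LevelKeySketch {
    // Assumption: placeholder for CommonConstants.DEFAULT_LEVEL
    static final String DEFAULT_LEVEL = "default";

    /** Stand-in for the cached user; only the limit level matters here. */
    static final class StoreUserSketch {
        final int limitLevel;
        StoreUserSketch(int limitLevel) { this.limitLevel = limitLevel; }
    }

    static String resolveKey(Map<String, StoreUserSketch> tokenStore, String token) {
        StoreUserSketch user = token == null ? null : tokenStore.get(token);
        if (user == null || user.limitLevel == 0) {
            // Unknown token or no level assigned: use the default level
            return DEFAULT_LEVEL;
        }
        // The key is the level, so users with the same level share one bucket
        return String.valueOf(user.limitLevel);
    }

    public static void main(String[] args) {
        Map<String, StoreUserSketch> store = new HashMap<>();
        store.put("token-a", new StoreUserSketch(2));
        store.put("token-b", new StoreUserSketch(2));
        System.out.println(resolveKey(store, "token-a"));
        System.out.println(resolveKey(store, "token-b"));
        System.out.println(resolveKey(store, "unknown"));
    }
}
```

Two different tokens whose users both have level 2 resolve to the same key, which is why they end up in the same bucket.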
3. Custom Redis rate limiter
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.ratelimit.AbstractRateLimiter;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.util.ObjectUtils;
import org.springframework.validation.Validator;
import org.springframework.validation.annotation.Validated;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import javax.validation.constraints.Min;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
// LimiterLevelResolver, RateLimiterLevel, RateLimiterVO and CommonConstants are classes from the author's project
@Slf4j
public class CustomRedisRateLimiter extends AbstractRateLimiter<CustomRedisRateLimiter.Config> implements ApplicationContextAware {
    public static final String CONFIGURATION_PROPERTY_NAME = "redis-rate-limiter";
    public static final String REDIS_SCRIPT_NAME = "redisRequestRateLimiterScript";
    public static final String REMAINING_HEADER = "X-RateLimit-Remaining";
    public static final String REPLENISH_RATE_HEADER = "X-RateLimit-Replenish-Rate";
    public static final String BURST_CAPACITY_HEADER = "X-RateLimit-Burst-Capacity";

    private ReactiveRedisTemplate<String, String> redisTemplate;
    private RedisScript<List<Long>> script;
    private final AtomicBoolean initialized = new AtomicBoolean(false);
    /** The name of the header that returns the number of remaining tokens. */
    private String remainingHeader = REMAINING_HEADER;
    /** The name of the header that returns the replenish rate configuration. */
    private String replenishRateHeader = REPLENISH_RATE_HEADER;
    /** The name of the header that returns the burst capacity configuration. */
    private String burstCapacityHeader = BURST_CAPACITY_HEADER;

    @Autowired
    private LimiterLevelResolver limiterLevelResolver;

    public CustomRedisRateLimiter(ReactiveRedisTemplate<String, String> redisTemplate,
                                  RedisScript<List<Long>> script, Validator validator) {
        super(Config.class, CONFIGURATION_PROPERTY_NAME, validator);
        this.redisTemplate = redisTemplate;
        this.script = script;
        initialized.compareAndSet(false, true);
    }

    @Override
    public Mono<RateLimiter.Response> isAllowed(String routeId, String id) {
        if (!this.initialized.get()) {
            throw new IllegalStateException("RedisRateLimiter is not initialized");
        }
        if (ObjectUtils.isEmpty(limiterLevelResolver)) {
            throw new IllegalArgumentException("No Configuration found for route " + routeId);
        }
        RateLimiterLevel rateLimiterLevel = limiterLevelResolver.get();
        // Look up the configuration of the resolved level once
        Optional<RateLimiterVO> level = rateLimiterLevel
                .getLevels()
                .stream()
                .filter(rateLimiterVO -> rateLimiterVO.getLevel().equals(id))
                .findFirst();
        // How many requests per time unit do you want a user to be allowed to do?
        int replenishRate = level.map(RateLimiterVO::getReplenishRate).orElse(CommonConstants.DEFAULT_LIMIT_LEVEL);
        // How much bursting do you want to allow?
        int burstCapacity = level.map(RateLimiterVO::getBurstCapacity).orElse(CommonConstants.DEFAULT_LIMIT_LEVEL);
        // Time unit of the rate: 1 second, 2 minute, 3 hour, 4 day
        int limitType = level.map(RateLimiterVO::getLimitType).orElse(CommonConstants.DEFAULT_LIMIT_TYPE);
        try {
            List<String> keys = getKeys(id);
            long limitTime = getTime(limitType);
            // Script arguments: rate, capacity, current time in the chosen unit, tokens requested
            List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "", limitTime + "", "1");
            Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);
            // If Redis fails, let the request through instead of blocking all traffic
            return flux.onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L)))
                    .reduce(new ArrayList<Long>(), (longs, l) -> {
                        longs.addAll(l);
                        return longs;
                    })
                    .map(results -> {
                        boolean allowed = results.get(0) == 1L;
                        Long tokensLeft = results.get(1);
                        return new RateLimiter.Response(allowed, getHeaders(replenishRate, burstCapacity, tokensLeft));
                    });
        } catch (Exception e) {
            log.error("Error determining if user is allowed by the rate limiter", e);
        }
        return Mono.just(new RateLimiter.Response(true, getHeaders(replenishRate, burstCapacity, -1L)));
    }

    public HashMap<String, String> getHeaders(Integer replenishRate, Integer burstCapacity, Long tokensLeft) {
        HashMap<String, String> headers = new HashMap<>();
        headers.put(this.remainingHeader, tokensLeft.toString());
        headers.put(this.replenishRateHeader, String.valueOf(replenishRate));
        headers.put(this.burstCapacityHeader, String.valueOf(burstCapacity));
        return headers;
    }

    static List<String> getKeys(String id) {
        // use `{}` around keys to use Redis Key hash tags
        // this allows for using redis cluster
        // Make a unique key per user.
        String prefix = "request_user_rate_limiter.{" + id;
        // You need two Redis keys for Token Bucket.
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    /**
     * Returns the current epoch time expressed in the configured unit.
     * @Date 14:52 2019/7/15
     * @param type 1: seconds, 2: minutes, 3: hours, 4: days
     * @return the current time in that unit
     **/
    public long getTime(int type) {
        // getEpochSecond() already returns seconds, so divide by seconds per unit
        long seconds = Instant.now().getEpochSecond();
        switch (type) {
            case 2:
                return seconds / 60;
            case 3:
                return seconds / (60 * 60);
            case 4:
                return seconds / (60 * 60 * 24);
            case 1:
            default:
                return seconds;
        }
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        if (initialized.compareAndSet(false, true)) {
            this.redisTemplate = applicationContext.getBean("stringReactiveRedisTemplate", ReactiveRedisTemplate.class);
            this.script = applicationContext.getBean(REDIS_SCRIPT_NAME, RedisScript.class);
            if (applicationContext.getBeanNamesForType(Validator.class).length > 0) {
                this.setValidator(applicationContext.getBean(Validator.class));
            }
        }
    }

    @Validated
    public static class Config {
        @Min(1)
        private int replenishRate;
        @Min(1)
        private int burstCapacity = 1;

        public int getReplenishRate() {
            return replenishRate;
        }

        public Config setReplenishRate(int replenishRate) {
            this.replenishRate = replenishRate;
            return this;
        }

        public int getBurstCapacity() {
            return burstCapacity;
        }

        public Config setBurstCapacity(int burstCapacity) {
            this.burstCapacity = burstCapacity;
            return this;
        }

        @Override
        public String toString() {
            return "Config{" +
                    "replenishRate=" + replenishRate +
                    ", burstCapacity=" + burstCapacity +
                    '}';
        }
    }
}
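getTime converts the current epoch second into a count of minutes, hours or days. Because Instant.getEpochSecond() already returns seconds, the divisors must be 60, 3,600 and 86,400. The conversion on its own, as a hypothetical TimeUnitSketch class:

```java
import java.time.Instant;

/** Hypothetical standalone version of the unit conversion (1 s, 2 min, 3 h, 4 day). */
final class TimeUnitSketch {
    static long toUnit(long epochSecond, int type) {
        switch (type) {
            case 2:
                return epochSecond / 60;        // minutes
            case 3:
                return epochSecond / 3_600;     // hours
            case 4:
                return epochSecond / 86_400;    // days
            default:
                return epochSecond;             // seconds (type 1 or unknown)
        }
    }

    public static void main(String[] args) {
        long now = Instant.now().getEpochSecond();
        for (int type = 1; type <= 4; type++) {
            System.out.println(type + " -> " + toUnit(now, type));
        }
    }
}
```

Assuming the Lua script treats this argument as its current time, replenishRate then means tokens added per minute, hour or day, and two requests in the same minute see no elapsed time. One caveat worth checking: if the script also derives its key TTL (in seconds, via SETEX) from capacity and rate, the bucket keys may expire before a whole minute, hour or day has passed.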
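The core logic is in isAllowed: it looks up the replenish rate, burst capacity and time unit configured for the resolved level, then runs the token-bucket Lua script (the redisRequestRateLimiterScript bean) in Redis against the two keys built by getKeys.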
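The lookup at the top of isAllowed can be sketched on its own. LevelConfig below mirrors the fields RateLimiterVO appears to have, and the default values stand in for CommonConstants.DEFAULT_LIMIT_LEVEL and DEFAULT_LIMIT_TYPE, whose real values are not shown; all of these names are hypothetical. The key layout is copied from getKeys: the `{id}` hash tag makes Redis Cluster hash only the id, so both keys land in the same slot, as a multi-key script requires.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** Hypothetical model of isAllowed's per-level lookup and key construction. */
final class LevelLookupSketch {
    static final int DEFAULT_RATE = 1;  // assumption: placeholder for DEFAULT_LIMIT_LEVEL
    static final int DEFAULT_TYPE = 1;  // assumption: placeholder for DEFAULT_LIMIT_TYPE (seconds)

    static final class LevelConfig {
        final String level;
        final int replenishRate;
        final int burstCapacity;
        final int limitType;

        LevelConfig(String level, int replenishRate, int burstCapacity, int limitType) {
            this.level = level;
            this.replenishRate = replenishRate;
            this.burstCapacity = burstCapacity;
            this.limitType = limitType;
        }
    }

    /** Returns {replenishRate, burstCapacity, limitType} for the level, or defaults. */
    static int[] settingsFor(List<LevelConfig> levels, String id) {
        Optional<LevelConfig> match = levels.stream().filter(c -> c.level.equals(id)).findFirst();
        return new int[] {
                match.map(c -> c.replenishRate).orElse(DEFAULT_RATE),
                match.map(c -> c.burstCapacity).orElse(DEFAULT_RATE),
                match.map(c -> c.limitType).orElse(DEFAULT_TYPE)
        };
    }

    /** Same layout as getKeys: one tokens key and one timestamp key per id. */
    static List<String> keysFor(String id) {
        String prefix = "request_user_rate_limiter.{" + id;
        return Arrays.asList(prefix + "}.tokens", prefix + "}.timestamp");
    }

    public static void main(String[] args) {
        List<LevelConfig> levels = Arrays.asList(new LevelConfig("1", 10, 20, 2));
        System.out.println(Arrays.toString(settingsFor(levels, "1")));
        System.out.println(Arrays.toString(settingsFor(levels, "9")));
        System.out.println(keysFor("1"));
    }
}
```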
Full source code: https://github.com/yzcheng90/ms