package com.jd.car.parrot.web.plugin.ratelimter; import com.google.common.collect.Lists; import com.car.parrot.web.plugin.config.Singleton; import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.io.ClassPathResource; import org.springframework.data.redis.core.ReactiveRedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import org.springframework.data.redis.core.script.RedisScript; import org.springframework.scripting.support.ResourceScriptSource; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; /** * See https://stripe.com/blog/rate-limiters and * https://gist.github.com/ptarjan/e38f45f2dfe601419ca3af937fff574d#file-1-check_request_rate_limiter-rb-L11-L34 * See https://github.com/spring-cloud/spring-cloud-gateway/blob/master/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiter.java * RedisRateLimiter. * * @author zero */ @Slf4j public class RedisRateLimiter { /** * logger. */ private static final Logger LOGGER = LoggerFactory.getLogger(RedisRateLimiter.class); private RedisScript<List<Long>> script; private AtomicBoolean initialized = new AtomicBoolean(false); /** * Instantiates a new Redis rate limiter. */ public RedisRateLimiter() { this.script = redisScript(); initialized.compareAndSet(false, true); } /** * This uses a basic token bucket algorithm and relies on the fact that Redis scripts * execute atomically. No other operations can run between fetching the count and * writing the new count. 
* * @param id is rule id * @param replenishRate replenishRate * @param burstCapacity burstCapacity * @return {@code Mono<Response>} to indicate when request processing is complete */ @SuppressWarnings("unchecked") public Mono<RateLimiterResponse> isAllowed(final String id, final double replenishRate, final double burstCapacity) { if (!this.initialized.get()) { throw new IllegalStateException("RedisRateLimiter is not initialized"); } try { long startTime = System.currentTimeMillis(); List<String> keys = getKeys(id); List<String> args = getArgs(replenishRate, burstCapacity); Flux<List<Long>> resultFlux = Singleton .INST .get(ReactiveRedisTemplate.class) .execute(this.script, keys, args); RateLimiterResponse limiterResponse = resultFlux .onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L))) .reduce(new ArrayList<Long>(), (longs, l) -> { longs.addAll(l); return longs; }).map(results -> { boolean allowed = results.get(0) == 1L; Long tokensLeft = results.get(1); return new RateLimiterResponse(allowed, tokensLeft); }).block(); log.debug("limiterResponse:{} cost:{}", limiterResponse, (System.currentTimeMillis() - startTime)); return Mono.just(limiterResponse); } catch (Exception e) { LOGGER.error("Error determining if user allowed from redis:", e); } return Mono.just(new RateLimiterResponse(true, -1)); } private static List<String> getKeys(final String id) { String prefix = "request_rate_limiter.{" + id; String tokenKey = prefix + "}.tokens"; String timestampKey = prefix + "}.timestamp"; log.debug("tokenKey:{},timestampKey:{}", tokenKey, timestampKey); return Arrays.asList(tokenKey, timestampKey); } @SuppressWarnings("unchecked") private RedisScript<List<Long>> redisScript() { DefaultRedisScript redisScript = new DefaultRedisScript<>(); redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("/META-INF/scripts/request_rate_limiter.lua"))); redisScript.setResultType(List.class); return redisScript; } /** * 获取lua脚本执行参数 * * 注意:虽然当前电脑普遍时钟精度只到毫秒 * * 
@param replenishRate 令牌桶令牌生成速率 * @param burstCapacity 令牌桶容量 * @return lua脚本执行参数 */ private List<String> getArgs(final double replenishRate, final double burstCapacity) { Instant now = Instant.now(); long ninthPowerOfTen = 1000 * 1000 * 1000L; long nowNano = now.getEpochSecond() * ninthPowerOfTen + now.getNano(); ArrayList<String> scriptArgs = Lists.newArrayListWithCapacity(4); scriptArgs.add(replenishRate + ""); scriptArgs.add(burstCapacity + ""); scriptArgs.add(nowNano + ""); scriptArgs.add("1"); return scriptArgs; } }
-- Token-bucket rate limiter executed atomically inside Redis.
--
-- KEYS[1]  key holding the number of tokens left in the bucket
-- KEYS[2]  key holding the timestamp (ns) up to which tokens have been credited
-- ARGV[1]  refill rate, tokens per second
-- ARGV[2]  bucket capacity
-- ARGV[3]  current time in nanoseconds
-- ARGV[4]  number of tokens requested
--
-- Returns { allowed (1 or 0), tokens left after this request }.
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

-- Seconds needed to refill an empty bucket.
local fill_time = capacity / rate
-- Keys expire after twice the fill time. SETEX rejects an expire time below 1,
-- so clamp to 1 second for buckets that refill in under half a second.
local ttl = math.max(1, math.floor(fill_time * 2))

-- Tokens left after the previous request; a missing key means a full bucket.
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end

-- Timestamp up to which tokens have been credited; 0 when unknown.
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end

-- Elapsed time, clamped at 0 so a lagging clock cannot produce a negative refill.
local delta = math.max(0, now - last_refreshed)
-- Refill rate per nanosecond, matching the nanosecond timestamps.
local nano_rate = rate / 1000000000
-- Only whole tokens are credited.
local delta_tokens = math.floor(delta * nano_rate)
local total_tokens = last_tokens + delta_tokens
-- Tokens beyond the capacity overflow and are discarded.
local filled_tokens = math.floor(math.min(capacity, total_tokens))

-- Whether the bucket holds enough tokens for this request.
local allowed = filled_tokens >= requested
-- Tokens to store back into Redis.
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

-- Advance the refill timestamp only by the time that produced whole tokens, so
-- the fractional remainder carries over to the next call. When no token was
-- credited the timestamp stays put. When the bucket overflowed, the remainder
-- is discarded and refilling restarts from now.
local new_refreshed = last_refreshed
if delta_tokens > 0 then
  if total_tokens >= capacity then
    new_refreshed = now
  else
    new_refreshed = last_refreshed + delta_tokens / nano_rate
  end
end

redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, new_refreshed)

return { allowed_num, new_tokens }