自定义限流注解,当频繁调用注解的方法时,会触发限流规则(限流规则支持配置多个),比如10s内只支持20次调用,当超过20次时会被限流。
限流注解RateLimiter,包含限流key、类型、限流规则等信息
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface RateLimiter {
/**
* 限流key
*/
String key() default "redis_rate_limit";
/**
* 限流类型 ( 默认 Ip 模式 )
*/
LimitTypeEnum limitType() default LimitTypeEnum.IP;
/**
* 错误提示
*/
String message() default "失败";
/**
* 限流规则 (规则不可变,可多规则)
*/
RateRule[] rules() default {};
/**
* 防重复提交值
*/
boolean preventDuplicate() default false;
/**
* 防重复提交默认值
*/
RateRule preventDuplicateRule() default @RateRule(count = 1, time = 5);
}
限流规则,默认60s允许10次,支持配置多个规则
@Target(ElementType.ANNOTATION_TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface RateRule {
/**
* 限流次数
*/
long count() default 10;
/**
* 限流时间
*/
long time() default 60;
/**
* 限流时间单位
*/
TimeUnit timeUnit() default TimeUnit.SECONDS;
}
支持IP、客户userId、全局类型,也可以自定义增加其他类型
public enum LimitTypeEnum {
IP,
USER_ID,
GLOBAL;
}
@Aspect
@Component
@Slf4j
public class RateLimiterAop {
@Autowired
private RedisTemplate<String, Object> redisTemplate;
private final DefaultRedisScript<Boolean> rateLimitScript = new DefaultRedisScript<>();
private final static String RATE_LUA_NAME = "Redis限流脚本";
private final static String RATE_LUA = "-- 1. 获取参数\n" +
"local key = KEYS[1]\n" +
"local currentTime = KEYS[2]\n" +
"-- 2. 以数组最大值为 ttl 最大值\n" +
"local expireTime = -1;\n" +
"-- 3. 遍历数组查看是否越界\n" +
"for i = 1, #ARGV, 2 do\n" +
" local rateRuleCount = tonumber(ARGV[i])\n" +
" local rateRuleTime = tonumber(ARGV[i + 1])\n" +
" -- 3.1 判断在单位时间内访问次数\n" +
" local count = redis.call('ZCOUNT', key, currentTime - rateRuleTime, currentTime)\n" +
" -- 3.2 判断是否超过规定次数\n" +
" if tonumber(count) >= rateRuleCount then\n" +
" return true\n" +
" end\n" +
" -- 3.3 判断元素最大值,设置为最终过期时间\n" +
" if rateRuleTime > expireTime then\n" +
" expireTime = rateRuleTime\n" +
" end\n" +
"end\n" +
"-- 4. 更新缓存过期时间\n" +
"redis.call('PEXPIRE', key, expireTime)\n" +
"-- 5. 删除最大时间限度之前的数据,防止数据过多\n" +
"redis.call('ZREMRANGEBYSCORE', key, 0, currentTime - expireTime)\n" +
"-- 6. redis 中添加当前时间 ( 解决多个线程在同一毫秒添加相同 value 导致 Redis 漏记的问题 )\n" +
"-- 6.1 maxRetries 最大重试次数 retries 重试次数\n" +
"local maxRetries = 5\n" +
"local retries = 0\n" +
"while true do\n" +
" local result = redis.call('ZADD', key, currentTime, currentTime)\n" +
" if result == 1 then\n" +
" -- 6.2 添加成功则跳出循环\n" +
" break\n" +
" else\n" +
" -- 6.3 未添加成功则 value + 1 再次进行尝试\n" +
" retries = retries + 1\n" +
" if retries >= maxRetries then\n" +
" -- 6.4 超过最大尝试次数 采用添加随机数策略\n" +
" local random_value = math.random(1, 1000)\n" +
" currentTime = currentTime + random_value\n" +
" else\n" +
" currentTime = currentTime + 1\n" +
" end\n" +
" end\n" +
"end\n" +
" \n" +
"return false";
@PostConstruct
public void loadScript() {
rateLimitScript.setScriptText(RATE_LUA);
rateLimitScript.setResultType(Boolean.class);
loadRedisScript(rateLimitScript, RATE_LUA_NAME);
}
private void loadRedisScript(DefaultRedisScript<Boolean> redisScript, String luaName) {
try {
List<Boolean> results = redisTemplate.getConnectionFactory().getConnection()
.scriptExists(redisScript.getSha1());
if (Boolean.FALSE.equals(results.get(0))) {
String sha = redisTemplate.getConnectionFactory().getConnection()
.scriptLoad(redisScript.getScriptAsString().getBytes(StandardCharsets.UTF_8));
log.info("预加载lua脚本成功:{}, sha=[{}]", luaName, sha);
}
} catch (Exception e) {
log.error("预加载lua脚本异常:{}", luaName, e);
}
}
/**
* 限流
* @param joinPoint joinPoint
* @param rateLimiter 限流注解
*/
@Before(value = "@annotation(rateLimiter)")
public void boBefore(JoinPoint joinPoint, RateLimiter rateLimiter) throws Exception {
// 1. 生成 key
String key = getCombineKey(rateLimiter, joinPoint);
// 2. 执行脚本返回是否限流 lua脚本参数1为key、参数2为当前时间
Boolean flag = redisTemplate.execute(rateLimitScript,
List.of(key, String.valueOf(System.currentTimeMillis())),
(Object[]) getRules(rateLimiter));
// 3. 判断是否限流
if (Boolean.TRUE.equals(flag)) {
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
String ip = requestAttributes.getRequest().getHeader("ip");
log.error("ip: '{}' 拦截到一个请求 RedisKey: '{}'", ip, key);
throw new Exception(rateLimiter.message());
}
}
/**
* 通过 rateLimiter 和 joinPoint 拼接 prefix : ip / userId : classSimpleName - methodName
*
* @param rateLimiter 提供 prefix
* @param joinPoint 提供 classSimpleName : methodName
* @return
*/
public String getCombineKey(RateLimiter rateLimiter, JoinPoint joinPoint) {
StringBuffer key = new StringBuffer(rateLimiter.key());
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
String userId = requestAttributes.getRequest().getHeader("userId");
String ip = requestAttributes.getRequest().getHeader("ip");
// 不同限流类型使用不同的前缀
switch (rateLimiter.limitType()) {
// XXX 可以新增通过参数指定参数进行限流
case IP:
key.append(ip).append(":");
break;
case USER_ID:
key.append(userId).append(":");
break;
case GLOBAL:
break;
}
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
key.append(targetClass.getSimpleName()).append("-").append(method.getName());
return key.toString();
}
/**
* 获取规则
*
* @param rateLimiter 获取其中规则信息
* @return
*/
private Long[] getRules(RateLimiter rateLimiter) {
int capacity = rateLimiter.rules().length << 1;
// 1. 构建 args
Long[] args = new Long[rateLimiter.preventDuplicate() ? capacity + 2 : capacity];
// 3. 记录数组元素
int index = 0;
// 2. 判断是否需要添加防重复提交到redis进行校验
if (rateLimiter.preventDuplicate()) {
RateRule preventRateRule = rateLimiter.preventDuplicateRule();
args[index++] = preventRateRule.count();
args[index++] = preventRateRule.timeUnit().toMillis(preventRateRule.time());
}
RateRule[] rules = rateLimiter.rules();
for (RateRule rule : rules) {
args[index++] = rule.count();
args[index++] = rule.timeUnit().toMillis(rule.time());
}
return args;
}
}
测试Controller
@RestController
public class RateLimiterTest {
@RequestMapping("/rateLimiter")
@RateLimiter(
// 60秒内只能访问10次,支持配置多个RateRule
rules = {@RateRule(count = 10, time = 60, timeUnit = TimeUnit.SECONDS),}
)
public void rateTest() {
System.out.println("测试RateLimiter");
}
}
使用postman访问/rateLimiter接口,在http请求头添加参数ip(默认使用ip限流),当60s内超过10次时会被限流,打印限流日志