Redis实现限流功能

自定义限流注解,当频繁调用注解的方法时,会触发限流规则(限流规则支持配置多个),比如10s内只支持20次调用,当超过20次时会被限流。
限流注解RateLimiter,包含限流key、类型、限流规则等信息

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface RateLimiter {
    /**
     * 限流key
     */
    String key() default "redis_rate_limit";

    /**
     * 限流类型 ( 默认 Ip 模式 )
     */
    LimitTypeEnum limitType() default LimitTypeEnum.IP;

    /**
     * 错误提示
     */
    String message() default "失败";

    /**
     * 限流规则 (规则不可变,可多规则)
     */
    RateRule[] rules() default {};

    /**
     * 防重复提交值
     */
    boolean preventDuplicate() default false;

    /**
     * 防重复提交默认值
     */
    RateRule preventDuplicateRule() default @RateRule(count = 1, time = 5);
}

限流规则,默认60s允许10次,支持配置多个规则

@Target(ElementType.ANNOTATION_TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface RateRule {
    /**
     * 限流次数
     */
    long count() default 10;

    /**
     * 限流时间
     */
    long time() default 60;

    /**
     * 限流时间单位
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;

}

支持IP、客户userId、全局类型,也可以自定义增加其他类型

public enum LimitTypeEnum {
    IP,
    USER_ID,
    GLOBAL;
}
@Aspect
@Component
@Slf4j
public class RateLimiterAop {
    @Autowired
    private RedisTemplate<String, Object> redisTemplate;
    private final DefaultRedisScript<Boolean> rateLimitScript = new DefaultRedisScript<>();
    private final static String RATE_LUA_NAME = "Redis限流脚本";
    private final static String RATE_LUA = "-- 1. 获取参数\n" +
    "local key = KEYS[1]\n" +
    "local currentTime = KEYS[2]\n" +
    "-- 2. 以数组最大值为 ttl 最大值\n" +
    "local expireTime = -1;\n" +
    "-- 3. 遍历数组查看是否越界\n" +
    "for i = 1, #ARGV, 2 do\n" +
    "    local rateRuleCount = tonumber(ARGV[i])\n" +
    "    local rateRuleTime = tonumber(ARGV[i + 1])\n" +
    "    -- 3.1 判断在单位时间内访问次数\n" +
    "    local count = redis.call('ZCOUNT', key, currentTime - rateRuleTime, currentTime)\n" +
    "    -- 3.2 判断是否超过规定次数\n" +
    "    if tonumber(count) >= rateRuleCount then\n" +
    "        return true\n" +
    "    end\n" +
    "    -- 3.3 判断元素最大值,设置为最终过期时间\n" +
    "    if rateRuleTime > expireTime then\n" +
    "        expireTime = rateRuleTime\n" +
    "    end\n" +
    "end\n" +
    "-- 4. 更新缓存过期时间\n" +
    "redis.call('PEXPIRE', key, expireTime)\n" +
    "-- 5. 删除最大时间限度之前的数据,防止数据过多\n" +
    "redis.call('ZREMRANGEBYSCORE', key, 0, currentTime - expireTime)\n" +
    "-- 6. redis 中添加当前时间  ( 解决多个线程在同一毫秒添加相同 value 导致 Redis 漏记的问题 )\n" +
    "-- 6.1 maxRetries 最大重试次数 retries 重试次数\n" +
    "local maxRetries = 5\n" +
    "local retries = 0\n" +
    "while true do\n" +
    "    local result = redis.call('ZADD', key, currentTime, currentTime)\n" +
    "    if result == 1 then\n" +
    "        -- 6.2 添加成功则跳出循环\n" +
    "        break\n" +
    "    else\n" +
    "        -- 6.3 未添加成功则 value + 1 再次进行尝试\n" +
    "        retries = retries + 1\n" +
    "        if retries >= maxRetries then\n" +
    "            -- 6.4 超过最大尝试次数 采用添加随机数策略\n" +
    "            local random_value = math.random(1, 1000)\n" +
    "            currentTime = currentTime + random_value\n" +
    "        else\n" +
    "            currentTime = currentTime + 1\n" +
    "        end\n" +
    "    end\n" +
    "end\n" +
    " \n" +
    "return false";

    @PostConstruct
    public void loadScript() {
        rateLimitScript.setScriptText(RATE_LUA);
        rateLimitScript.setResultType(Boolean.class);
        loadRedisScript(rateLimitScript, RATE_LUA_NAME);
    }

    private void loadRedisScript(DefaultRedisScript<Boolean> redisScript, String luaName) {
        try {
            List<Boolean> results = redisTemplate.getConnectionFactory().getConnection()
            .scriptExists(redisScript.getSha1());
            if (Boolean.FALSE.equals(results.get(0))) {
                String sha = redisTemplate.getConnectionFactory().getConnection()
                        .scriptLoad(redisScript.getScriptAsString().getBytes(StandardCharsets.UTF_8));
                log.info("预加载lua脚本成功:{}, sha=[{}]", luaName, sha);
            }
        } catch (Exception e) {
            log.error("预加载lua脚本异常:{}", luaName, e);
        }
    }

    /**
     * 限流
     * @param joinPoint   joinPoint
     * @param rateLimiter 限流注解
     */
    @Before(value = "@annotation(rateLimiter)")
    public void boBefore(JoinPoint joinPoint, RateLimiter rateLimiter) throws Exception {
        // 1. 生成 key
        String key = getCombineKey(rateLimiter, joinPoint);
        // 2. 执行脚本返回是否限流  lua脚本参数1为key、参数2为当前时间
        Boolean flag = redisTemplate.execute(rateLimitScript,
                List.of(key, String.valueOf(System.currentTimeMillis())),
                (Object[]) getRules(rateLimiter));
        // 3. 判断是否限流
        if (Boolean.TRUE.equals(flag)) {
            ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            String ip = requestAttributes.getRequest().getHeader("ip");
            log.error("ip: '{}' 拦截到一个请求 RedisKey: '{}'", ip, key);
            throw new Exception(rateLimiter.message());
        }
    }

    /**
     * 通过 rateLimiter 和 joinPoint 拼接  prefix : ip / userId : classSimpleName - methodName
     *
     * @param rateLimiter 提供 prefix
     * @param joinPoint   提供 classSimpleName : methodName
     * @return
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint joinPoint) {
        StringBuffer key = new StringBuffer(rateLimiter.key());
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        String userId = requestAttributes.getRequest().getHeader("userId");
        String ip = requestAttributes.getRequest().getHeader("ip");
        // 不同限流类型使用不同的前缀
        switch (rateLimiter.limitType()) {
            // XXX 可以新增通过参数指定参数进行限流
            case IP:
                key.append(ip).append(":");
                break;
            case USER_ID:
                key.append(userId).append(":");
                break;
            case GLOBAL:
                break;
        }
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        key.append(targetClass.getSimpleName()).append("-").append(method.getName());
        return key.toString();
    }

    /**
     * 获取规则
     *
     * @param rateLimiter 获取其中规则信息
     * @return
     */
    private Long[] getRules(RateLimiter rateLimiter) {
        int capacity = rateLimiter.rules().length << 1;
        // 1. 构建 args
        Long[] args = new Long[rateLimiter.preventDuplicate() ? capacity + 2 : capacity];
        // 3. 记录数组元素
        int index = 0;
        // 2. 判断是否需要添加防重复提交到redis进行校验
        if (rateLimiter.preventDuplicate()) {
            RateRule preventRateRule = rateLimiter.preventDuplicateRule();
            args[index++] = preventRateRule.count();
            args[index++] = preventRateRule.timeUnit().toMillis(preventRateRule.time());
        }
        RateRule[] rules = rateLimiter.rules();
        for (RateRule rule : rules) {
            args[index++] = rule.count();
            args[index++] = rule.timeUnit().toMillis(rule.time());
        }
        return args;
    }
}

测试Controller

@RestController
public class RateLimiterTest {
    @RequestMapping("/rateLimiter")
    @RateLimiter(
        // 60秒内只能访问10次,支持配置多个RateRule
        rules = {@RateRule(count = 10, time = 60, timeUnit = TimeUnit.SECONDS),}
    )
    public void rateTest() {
        System.out.println("测试RateLimiter");
    }
}

使用postman访问/rateLimiter接口,在http请求头添加参数ip(默认使用ip限流),当60s内超过10次时会被限流,打印限流日志
在这里插入图片描述

  • 10
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值