SpringBoot整合Redis + Lua脚本实现限流

SpringBoot整合Redis + Lua脚本实现限流

1. 引入依赖

注意:本文使用的 Spring Boot 版本为 2.2.6.RELEASE。

<dependencies>
    <!-- Spring Boot web starter (the demo exposes @RestController endpoints) -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Spring Data Redis starter (RedisTemplate, LettuceConnectionFactory) -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>
    <!-- AOP starter, required for the @Aspect-based LimitInterceptor -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-aop</artifactId>
    </dependency>
    <!-- Guava (ImmutableList is used by the interceptor) -->
    <dependency>
        <groupId>com.google.guava</groupId>
        <artifactId>guava</artifactId>
        <version>21.0</version>
    </dependency>
    <!-- commons-lang3 (StringUtils); no version here, presumably managed by the Spring Boot parent/BOM -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
    </dependency>

    <!-- Lombok (@Slf4j); optional so it is not passed on to consumers of this artifact -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>

2. 配置Redis

# Redis server backing the rate-limit counters (adjust host/port per environment)
spring:
  redis:
    host: 127.0.0.1
    port: 6379

3. 定义RedisTemplate

/**
 * Registers the {@code RedisTemplate<String, Serializable>} used by the rate limiter.
 *
 * <p>Keys are written as plain strings; values, hash keys/values and Lua script arguments and
 * results go through a Jackson JSON serializer.
 */
@Configuration
public class RedisLimiterHelper {

    /**
     * Builds the limiter's template.
     *
     * <p>NOTE: the method name, and therefore the bean name, is {@code RedisTemplate}. It is kept
     * unchanged so that any lookup by bean name keeps working.
     *
     * @param redisConnectionFactory the Lettuce connection factory the template is bound to
     * @return a fully initialised template
     */
    @Bean
    public RedisTemplate<String, Serializable> RedisTemplate(LettuceConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Serializable> template = new RedisTemplate<>();
        template.setConnectionFactory(redisConnectionFactory);

        Jackson2JsonRedisSerializer<Object> jsonSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper objectMapper = new ObjectMapper();
        jsonSerializer.setObjectMapper(objectMapper);

        template.setDefaultSerializer(jsonSerializer);
        // Keys are Strings. Serializing them as JSON would store them wrapped in quotes
        // ("\"test\"" instead of "test"), which makes them awkward to inspect or share with other
        // clients. getStringSerializer() is the template's built-in String serializer.
        template.setKeySerializer(template.getStringSerializer());
        template.setValueSerializer(jsonSerializer);
        // Hash keys are left on JSON: their type is chosen per call and need not be String.
        template.setHashKeySerializer(jsonSerializer);
        template.setHashValueSerializer(jsonSerializer);
        template.afterPropertiesSet();
        return template;
    }
}

4. 定义限流枚举类

/** Selects what a {@code Limit} counter is keyed by. */
public enum LimitType {

    /** Keyed by the custom key declared on the annotation. */
    CUSTOMER,

    /** Keyed by the IP address of the requesting client. */
    IP;
}

5. 定义限流注解

/**
 * Declares a Redis-backed rate limit.
 *
 * <p>The counter key starts with {@link #prefix()}, followed by a part chosen according to
 * {@link #limitType()}. At most {@link #count()} calls are allowed per window of
 * {@link #period()} seconds.
 *
 * <p>NOTE(review): {@code TYPE} is an allowed target, but the LimitInterceptor in this article
 * reads the annotation from the invoked method only. Confirm whether class-level use is supported.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface Limit {

    /** Descriptive name of the limited resource (the interceptor only logs it). */
    String name() default "";

    /** Custom key part; used as-is for {@link LimitType#CUSTOMER} limits. */
    String key() default "";

    /** Prefix placed in front of the resolved key part. */
    String prefix() default "";

    /** Length of one counting window, in seconds. */
    int period();

    /** Maximum number of calls permitted within one window. */
    int count();

    /** Strategy for deriving the key part; defaults to {@link LimitType#CUSTOMER}. */
    LimitType limitType() default LimitType.CUSTOMER;
}

6. 定义限流拦截器

/**
 * Aspect that enforces {@link Limit} on public methods with a fixed-window counter in Redis.
 *
 * <p>Each call runs a Lua script against the key {@code prefix + keyPart}:
 * <ul>
 *   <li>The first increment in a window sets a TTL of {@code period} seconds.</li>
 *   <li>Once the counter exceeds {@code count}, calls are rejected until the key expires.</li>
 * </ul>
 */
@Aspect
@Configuration
@Slf4j
public class LimitInterceptor {

    /** Header value treated as "no address" (some proxies send the literal {@code unknown}). */
    public static final String UNKNOW_KEY = "unknown";

    /**
     * The limiter script, built once. Sharing a single DefaultRedisScript avoids rebuilding the Lua
     * source and recomputing its SHA1 on every request.
     */
    private static final RedisScript<Number> LIMIT_SCRIPT =
            new DefaultRedisScript<>(buildLuaScript(), Number.class);

    private final RedisTemplate<String, Serializable> redisTemplate;

    @Autowired
    public LimitInterceptor(RedisTemplate<String, Serializable> limitRedisTemplate) {
        this.redisTemplate = limitRedisTemplate;
    }

    /**
     * Checks the rate limit before letting the annotated method run.
     *
     * @param point the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws RuntimeException "You have been dragged into the blacklist" when the limit is
     *     exceeded or no count was returned. Runtime exceptions and errors from Redis or from the
     *     method are rethrown unchanged. Checked exceptions from the method are wrapped in
     *     {@code RuntimeException("server exception")} with the original as cause.
     */
    @Around("execution(public * *(..)) && @annotation(com.hmds.redisdemo.config.Limit)")
    public Object interceptor(ProceedingJoinPoint point) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Limit limit = method.getAnnotation(Limit.class);

        String name = limit.name();
        String key = resolveKey(limit, method);
        int limitPeriod = limit.period();
        int limitCount = limit.count();
        ImmutableList<String> keys = ImmutableList.of(StringUtils.join(limit.prefix(), key));

        Number count;
        try {
            count = redisTemplate.execute(LIMIT_SCRIPT, keys, limitCount, limitPeriod);
        } catch (RuntimeException e) {
            log.error("LimitInterceptor error", e);
            throw e;
        }
        log.info("Access try count:{}, name:{}, key:{}", count, name, key);
        if (count == null || count.intValue() > limitCount) {
            throw new RuntimeException("You have been dragged into the blacklist");
        }

        try {
            return point.proceed();
        } catch (RuntimeException | Error e) {
            // Keep business exceptions intact so callers / @ExceptionHandler see the real type.
            throw e;
        } catch (Throwable e) {
            // Checked exceptions cannot escape this signature: wrap them, preserving the cause.
            throw new RuntimeException("server exception", e);
        }
    }

    /**
     * Resolves the key part that is appended to {@code prefix} for this invocation.
     *
     * <p>For IP limits the declared {@code key} (when non-empty) is combined with the client IP.
     * This prevents separately annotated endpoints from sharing a single per-IP counter.
     */
    private String resolveKey(Limit limit, Method method) {
        switch (limit.limitType()) {
            case IP: {
                String ip = getIpAddr();
                return StringUtils.isEmpty(limit.key()) ? ip : limit.key() + ":" + ip;
            }
            case CUSTOMER:
                return limit.key();
            default:
                // Only reachable if LimitType gains constants without a matching case here.
                return StringUtils.upperCase(method.getName());
        }
    }

    /**
     * Builds the fixed-window limiter script.
     *
     * <p>KEYS[1] is the counter key, ARGV[1] the allowed count and ARGV[2] the window in seconds.
     * Both branches return a Lua number, so Redis always sends an integer reply.
     *
     * @return the Lua source
     */
    private static String buildLuaScript() {
        StringBuilder lua = new StringBuilder();
        lua.append("local c");
        lua.append("\nc = redis.call('get',KEYS[1])");
        // Limit already exceeded: report the current count without incrementing further.
        lua.append("\nif c and tonumber(c) > tonumber(ARGV[1]) then");
        lua.append("\nreturn tonumber(c);");
        lua.append("\nend");
        // Increment the counter for this call.
        lua.append("\nc = redis.call('incr',KEYS[1])");
        lua.append("\nif tonumber(c) == 1 then");
        // First call of a new window: start the window by setting the TTL.
        lua.append("\nredis.call('expire',KEYS[1],ARGV[2])");
        lua.append("\nend");
        lua.append("\nreturn c;");
        return lua.toString();
    }

    /**
     * Returns the client IP of the current HTTP request.
     *
     * <p>Headers are checked in this order: {@code x-forwarded-for}, {@code Proxy-Client-IP},
     * {@code WL-Proxy-Client-IP}. The servlet remote address is the fallback. When the value is a
     * comma-separated chain, only the first (originating) entry is returned.
     *
     * <p>NOTE(review): these headers are client-supplied unless a trusted reverse proxy overwrites
     * them. A client can forge them to evade IP-based limits, so only trust them behind such a
     * proxy.
     *
     * @return the client IP, or {@code null} if even the remote address is unavailable
     * @throws IllegalStateException if called outside an HTTP request context
     */
    public String getIpAddr() {
        ServletRequestAttributes attributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            throw new IllegalStateException("IP-based rate limiting requires an active HTTP request");
        }
        HttpServletRequest request = attributes.getRequest();
        String ip = request.getHeader("x-forwarded-for");
        if (isUnknownIp(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknownIp(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknownIp(ip)) {
            ip = request.getRemoteAddr();
        }
        // X-Forwarded-For may be "client, proxy1, proxy2"; the originating client comes first.
        if (ip != null) {
            int comma = ip.indexOf(',');
            if (comma >= 0) {
                ip = ip.substring(0, comma).trim();
            }
        }
        return ip;
    }

    /** True when a header value carries no usable address. */
    private static boolean isUnknownIp(String ip) {
        return ip == null || ip.isEmpty() || UNKNOW_KEY.equalsIgnoreCase(ip);
    }
}

7. 接口

/**
 * Demo endpoints for {@link Limit}. Each endpoint returns how many times its body has actually
 * executed since startup; rejected calls do not increment the counter.
 */
@RestController
@Slf4j
public class LimiterController {
    private static final AtomicInteger TEST_1 = new AtomicInteger();
    private static final AtomicInteger TEST_2 = new AtomicInteger();
    private static final AtomicInteger TEST_3 = new AtomicInteger();

    /** Custom-key limit (default type): 3 calls per 10 s, shared by all callers. */
    @Limit(key = "test", period = 10, count = 3)
    @GetMapping("/test1")
    public int test1() {
        return TEST_1.incrementAndGet();
    }

    /** Same as {@link #test1()} but with the limit type spelled out explicitly. */
    @Limit(key = "customer_test", period = 10, count = 3, limitType = LimitType.CUSTOMER)
    @GetMapping("/test2")
    public int test2() {
        return TEST_2.incrementAndGet();
    }

    /** Per-client-IP limit: 3 calls per 10 s for each IP. */
    @Limit(key = "ip_test", period = 10, count = 3, limitType = LimitType.IP)
    @GetMapping("/test3")
    public int test3() {
        return TEST_3.incrementAndGet();
    }
}

8. 测试

在这里插入图片描述

10 秒内最多允许访问 3 次,超过 3 次后抛出异常提示。
在这里插入图片描述

  • 8
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值