Distributed, Multi-Dimensional Pre- and Post-Invocation Rate Limiting with Spring AOP + Redis

Overview

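In practice, operations such as sending SMS verification codes or posting comments need some form of rate limiting. Rate limiting can be split into pre-invocation limiting and post-invocation limiting.

Pre-invocation limiting checks the quota before the target endpoint is called, and the call is counted whether or not the endpoint throws an exception or returns the expected value.

Post-invocation limiting decides after the endpoint has been called whether the call should be counted, based on a specified condition. The condition supports EL expressions.

This post implements both kinds with Spring AOP, custom annotations, a Redis distributed lock and a Redis Lua script. It supports limiting by user, by client IP, and by a custom key given as an EL expression.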
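
To make the difference between the two modes concrete, here is a minimal plain-Java sketch. It is not the implementation from this post: the class LimitSemanticsSketch and its methods are hypothetical names, the counters live in an in-memory map instead of Redis, and window expiry is left out.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;
import java.util.function.Supplier;

// Hypothetical in-memory stand-in for the Redis counters; window expiry is omitted.
class LimitSemanticsSketch {
    private final Map<String, Integer> counters = new HashMap<>();

    // Pre-invocation: the call is counted before the target runs,
    // regardless of what the target later returns or throws.
    <T> T preLimited(String key, int limit, Supplier<T> target) {
        int used = counters.getOrDefault(key, 0);
        if (used >= limit) {
            throw new IllegalStateException("rate limited: " + key);
        }
        counters.put(key, used + 1);
        return target.get();
    }

    // Post-invocation: the quota is checked first, but the call is counted
    // only when the condition accepts the return value.
    <T> T postLimited(String key, int limit, Supplier<T> target, Predicate<T> condition) {
        if (counters.getOrDefault(key, 0) >= limit) {
            throw new IllegalStateException("rate limited: " + key);
        }
        T result = target.get();
        if (condition.test(result)) {
            counters.merge(key, 1, Integer::sum);
        }
        return result;
    }

    // Number of calls counted so far for the key.
    int count(String key) {
        return counters.getOrDefault(key, 0);
    }
}
```

With this sketch, a pre-limited call that returns an error code still uses up quota, while a post-limited call with the condition code == 200 uses quota only for successful responses.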

Configure the Maven dependencies

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>

Configure Redis

Configuration file application.yml

redis:
    host: 192.168.0.1
    port: 6379
    password:
    lettuce:
      pool:
        time-between-eviction-runs: 30s
        max-active: 10
        max-wait: -1ms
        min-idle: 0
        max-idle: 8

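The time-between-eviction-runs setting controls how often the pool's idle-connection evictor runs; each run trims the idle connections in the Redis pool down towards the number set by min-idle. It is best to set it: otherwise Lettuce frequently reports errors such as "An existing connection was forcibly closed by the remote host"...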

Configuration class

@Slf4j
@EnableCaching
@ConditionalOnClass(RedisOperations.class)
public class RedisConfiguration extends CachingConfigurerSupport {
    /**
     * RedisTemplate<String, Object> resolves to this RedisTemplate
     * RedisTemplate<String, String> resolves to StringRedisTemplate
     */
    @Bean
    @Primary
    public RedisTemplate redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate<Object, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        // Specify the type to serialize so that deserialization yields a Java object
        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        //objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        // Prevents stored objects from being read back as plain Maps
        objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.WRAPPER_ARRAY);
        serializer.setObjectMapper(objectMapper);
        RedisSerializer<String> stringRedisSerializer = new StringRedisSerializer();
        template.setKeySerializer(stringRedisSerializer);
        template.setHashKeySerializer(stringRedisSerializer);
        template.setValueSerializer(serializer);
        template.setHashValueSerializer(serializer);
        template.afterPropertiesSet();
        return template;
    }

    @Bean
    public RateLimitAspect rateLimitAspect(RedisTemplate<String, Object> redisTemplate, RedisLockService redisLock) {
        return new RateLimitAspect(redisTemplate, redisLock);
    }

    @Bean
    public RedisLockService redisLockService(RedisTemplate<String, Long> redisTemplate) {
        return new RedisLockService(redisTemplate);
    }

    @Bean
    public RedisLockAspect redisLockAspect(RedisLockService redisLockService) {
        return new RedisLockAspect(redisLockService);
    }
}

Define the rate-limiting annotations @KeyRateLimiter and @PostKeyRateLimiter

First create the enum RateLimiterType, which defines the supported limiting modes

public enum RateLimiterType {
    /**
     * Client IP
     */
    CLIENT_IP,

    /**
     * User
     */
    USER,

    /**
     * Custom mode; a key must be specified
     */
    CUSTOM
}

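The pre-invocation annotation is KeyRateLimiter. Its key supports EL expressions, so arguments of the target method can be used as the key value; type selects which limiting dimension to use.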

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(KeyRateLimiters.class)
public @interface KeyRateLimiter {
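    /**
     * Rate limit key; supports Spring EL
     *
     * @return the key
     */
    String key() default "";

    /**
     * Maximum number of calls (tokens) allowed per interval
     *
     * @return the number of tokens per interval
     */
    int limit() default 1;

    /**
     * Length of the interval, default 1
     */
    int interval() default 1;

    /**
     * Unit of the interval, default seconds
     */
    TimeUnit intervalUnit() default TimeUnit.SECONDS;

    /**
     * Limiting type; CUSTOM requires a key
     */
    RateLimiterType type() default RateLimiterType.CUSTOM;

    /**
     * Message returned when a call is rejected
     */
    String message() default "You are operating too fast, please try again later!";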
}
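
As a rough illustration of how type and key combine into a Redis key, the sketch below follows the same concatenation that the aspect's setKey method (shown later in this post) performs. The class LimitKeySketch, its Dimension enum and the buildKey signature are hypothetical; the client IP and user id are passed in directly, whereas the aspect reads them from the request and the security context.

```java
// Hypothetical helper that mirrors the key layout built by the aspect's setKey method.
class LimitKeySketch {
    enum Dimension { CLIENT_IP, USER, CUSTOM }

    // resolvedKey is the annotation key after any EL expression has been evaluated.
    static String buildKey(String prefix, Dimension type, String resolvedKey,
                           String clientIp, String userId) {
        String base = prefix + resolvedKey;
        switch (type) {
            case CLIENT_IP:
                // One counter per client IP
                return base + ":" + clientIp;
            case USER:
                // One counter per user
                return base + ":" + userId;
            default:
                // CUSTOM: the key itself is the only discriminator, so it must not be empty
                if (resolvedKey == null || resolvedKey.trim().isEmpty()) {
                    throw new IllegalArgumentException("A CUSTOM limiter needs a non-empty key");
                }
                return base;
        }
    }
}
```

One consequence of this layout: if key is left empty on a CLIENT_IP or USER limiter, the resulting Redis key contains no method name, so two annotated methods would share one counter per IP or user. Setting a distinct key per method avoids that.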

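The post-invocation annotation is PostKeyRateLimiter. Unlike KeyRateLimiter it adds a condition: the boolean result of the EL expression decides whether the call is counted. Its implementation also differs. Because the count is recorded after the method has finished, calls have to be serialized with a Redis distributed lock, so it is naturally slower than KeyRateLimiter. Both annotations use a Redis pipeline, and both can be stacked on the same method.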

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(PostKeyRateLimiters.class)
public @interface PostKeyRateLimiter {
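    /**
     * Rate limit key; supports Spring EL
     *
     * @return the key
     */
    String key() default "";

    /**
     * Maximum number of calls (tokens) allowed per interval
     *
     * @return the number of tokens per interval
     */
    int limit() default 1;

    /**
     * Length of the interval, default 1
     */
    int interval() default 1;

    /**
     * Unit of the interval, default seconds
     */
    TimeUnit intervalUnit() default TimeUnit.SECONDS;

    /**
     * Limiting type; CUSTOM requires a key
     */
    RateLimiterType type() default RateLimiterType.CUSTOM;

    /**
     * Condition expression; it can read the return value, e.g. #rtv.code == 200
     */
    String condition() default "";

    /**
     * Message returned when a call is rejected
     */
    String message() default "You are operating too fast, please try again later!";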
}
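
Why the post-invocation limiter needs a lock can be shown with a small plain-Java sketch. It is an illustration under simplifying assumptions, not the code from this post: a ReentrantLock stands in for the Redis distributed lock, an AtomicInteger stands in for the Redis counter, the condition is assumed to be always true, and PostLimitLockSketch is a hypothetical name.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

class PostLimitLockSketch {
    // Fires `callers` concurrent calls at a post limiter and returns how many reached the target.
    static int invocationsWithLock(int callers, int limit) {
        ReentrantLock lock = new ReentrantLock();     // stand-in for the distributed lock
        AtomicInteger counted = new AtomicInteger();  // stand-in for the Redis counter
        AtomicInteger invoked = new AtomicInteger();  // how often the business method ran
        ExecutorService pool = Executors.newFixedThreadPool(callers);
        CountDownLatch start = new CountDownLatch(1);
        for (int i = 0; i < callers; i++) {
            pool.execute(() -> {
                try {
                    start.await();                    // release all callers at once
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                lock.lock();
                try {
                    if (counted.get() >= limit) {
                        return;                       // rejected: quota already used
                    }
                    invoked.incrementAndGet();        // the business method would run here
                    counted.incrementAndGet();        // counted only after the call
                } finally {
                    lock.unlock();
                }
            });
        }
        start.countDown();
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return invoked.get();
    }
}
```

Because check, call and count happen inside one critical section, no more than limit callers reach the target. Without the lock, several callers could pass the check before any of them has been counted, which is the gap the distributed lock closes.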

Next come two container annotations, so that several limiter annotations can be used on the same method

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface KeyRateLimiters {
    KeyRateLimiter[] value();
}

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PostKeyRateLimiters {
    PostKeyRateLimiter[] value();
}

Create the config class; PostKeyRateLimiter and KeyRateLimiter above are both converted into RateLimitConfig instances

@Data
public class RateLimitConfig {
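    /**
     * Rate limit key
     */
    private String key;

    /**
     * Number of tokens per interval
     */
    private int limit;

    /**
     * Length of the interval
     */
    private int rateInterval;

    /**
     * Unit of the interval, default seconds
     */
    private TimeUnit intervalUnit;

    /**
     * Condition that decides whether a call is counted; a SpEL expression
     */
    private String condition;

    /**
     * Limiting type
     */
    private RateLimiterType type;

    /**
     * Message returned when a call is rejected
     */
    private String message;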

    public RateLimitConfig(PostKeyRateLimiter keyRateLimiter) {
        this.key = keyRateLimiter.key();
        this.limit = keyRateLimiter.limit();
        this.rateInterval = keyRateLimiter.interval();
        this.intervalUnit = keyRateLimiter.intervalUnit();
        this.message = keyRateLimiter.message();
        this.condition = keyRateLimiter.condition();
        this.type = keyRateLimiter.type();
    }

    public RateLimitConfig(KeyRateLimiter keyRateLimiter) {
        this.key = keyRateLimiter.key();
        this.limit = keyRateLimiter.limit();
        this.rateInterval = keyRateLimiter.interval();
        this.intervalUnit = keyRateLimiter.intervalUnit();
        this.message = keyRateLimiter.message();
        this.type = keyRateLimiter.type();
    }
}

Create the rate-limiting aspect RateLimitAspect

@Slf4j
@Aspect
@RequiredArgsConstructor
public class RateLimitAspect extends AbstractAspect {
    private final RedisTemplate<String, Object> redisTemplate;
    private final RedisLockService redisLock;
    private static RedisScript<Number> rateLuaScript;

    static {
        // Returns 0 (limited) or 1 (allowed and counted)
        String luaScript = "local current = tonumber(redis.call('get',KEYS[1]) or '0')\n" +
                "if current >= tonumber(ARGV[1]) then\n" +
                "\treturn 0\n" +
                "end\n" +
                "current = redis.call('incr',KEYS[1])\n" +
                "if current == 1 then\n" +
                "\tredis.call('pexpire',KEYS[1],ARGV[2])\n" +
                "end\n" +
                "return 1";
        rateLuaScript = new DefaultRedisScript<>(luaScript, Number.class);
    }

    /**
     * Pointcut definition; matches both the pre- and post-invocation limiter annotations
     */
    @Pointcut("@annotation(com.iwork.boot.redis.rt.KeyRateLimiter) " +
            "|| @annotation(com.iwork.boot.redis.rt.KeyRateLimiters)  " +
            "|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiter) " +
            "|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiters)")
    public void frontRateLimiter() {
    }

    @Around("frontRateLimiter()")
    public Object executeFront(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        List<RateLimitConfig> limitConfigs = new ArrayList<>(8);
        List<RateLimitConfig> postLimitConfigs = new ArrayList<>(4);
        KeyRateLimiter keyRateLimiter = method.getAnnotation(KeyRateLimiter.class);
        KeyRateLimiters keyRateLimiters = method.getAnnotation(KeyRateLimiters.class);
        PostKeyRateLimiter postKeyRateLimiter = method.getAnnotation(PostKeyRateLimiter.class);
        PostKeyRateLimiters postKeyRateLimiters = method.getAnnotation(PostKeyRateLimiters.class);

        if (keyRateLimiter != null) {
            limitConfigs.add(new RateLimitConfig(keyRateLimiter));
        }
        if (keyRateLimiters != null && keyRateLimiters.value().length > 0) {
            Stream.of(keyRateLimiters.value()).map(RateLimitConfig::new).forEach(limitConfigs::add);
        }
        if (postKeyRateLimiter != null) {
            postLimitConfigs.add(new RateLimitConfig(postKeyRateLimiter));
        }
        if (postKeyRateLimiters != null && postKeyRateLimiters.value().length > 0) {
            Stream.of(postKeyRateLimiters.value()).map(RateLimitConfig::new).forEach(postLimitConfigs::add);
        }

        // Pre-invocation check
        setKey("rt:front:", joinPoint, limitConfigs);
        Set<String> errMsgSet = validateFront(limitConfigs.toArray(new RateLimitConfig[]{}));
        if (!errMsgSet.isEmpty()) {
            // A dedicated exception should be thrown here and handled by a global exception handler
            throw new BusinessException(errMsgSet.toString());
        }

        // The post-invocation check has to hold the lock
        if (!postLimitConfigs.isEmpty()) {
            // Set the keys
            setKey("rt:post:", joinPoint, postLimitConfigs);
            String key = "locks:" + postLimitConfigs.iterator().next().getKey();
            // Run once the lock has been acquired
            return redisLock.executeWithLock(key, 10, 60, TimeUnit.SECONDS, () -> {
                SessionCallback<Number> callback = new SessionCallback<Number>() {
                    @Override
                    public Number execute(RedisOperations operations) throws DataAccessException {
                        ValueOperations kvValueOperations = operations.opsForValue();
                        for (RateLimitConfig postLimitConfig : postLimitConfigs) {
                            String key1 = postLimitConfig.getKey();
                            kvValueOperations.get(key1);
                        }
                        return null;
                    }
                };
                List<Object> objects = redisTemplate.executePipelined(callback);
                for (int i = 0; i < postLimitConfigs.size(); i++) {
                    Number val = (Number) objects.get(i);
                    RateLimitConfig rateLimitConfig = postLimitConfigs.get(i);
                    if (val != null && val.longValue() >= rateLimitConfig.getLimit()) {
                        errMsgSet.add(rateLimitConfig.getMessage());
                    }
                }

                if (!errMsgSet.isEmpty()) {
                    // A dedicated exception should be thrown here and handled by a global exception handler
                    throw new BusinessException(errMsgSet.toString());
                }
                try {
                    // Invoke the business method
                    Object proceed = joinPoint.proceed();
                    // Consume a token for every limiter whose condition matches
                    RateLimitConfig[] filterConfigs = postLimitConfigs.stream()
                            .filter(config -> parsePostSpEl(proceed, config))
                            .collect(Collectors.toList())
                            .toArray(new RateLimitConfig[]{});
                    validateFront(filterConfigs);
                    return proceed;
                } catch (BusinessException e) {
                    throw e;
                } catch (Throwable throwable) {
                    throw new BusinessException(throwable);
                }
            });
        }

        return joinPoint.proceed();
    }

    private Set<String> validateFront(RateLimitConfig... rateLimitConfigs) {
        Set<String> errorMsg = new HashSet<>(rateLimitConfigs.length);
        List<Object> objects = redisTemplate.executePipelined(new SessionCallback<Number>() {
            @Override
            public Number execute(RedisOperations operations) throws DataAccessException {
                for (RateLimitConfig limitConfig : rateLimitConfigs) {
                    // Do not pass a long here, otherwise Redis reports: ERR value is not an integer or out of range
                    int period = (int) limitConfig.getIntervalUnit().toMillis(limitConfig.getRateInterval());
                    operations.execute(rateLuaScript, Collections.singletonList(limitConfig.getKey()), limitConfig.getLimit(), period);
                }
                return null;
            }
        });

        for (int i = 0; i < rateLimitConfigs.length; i++) {
            Number val = (Number) objects.get(i);
            // Rate limited
            if (val.longValue() == 0L) {
                errorMsg.add(rateLimitConfigs[i].getMessage());
            }
        }

        return errorMsg;
    }

    private void setKey(String prefix, ProceedingJoinPoint joinPoint, List<RateLimitConfig> limitConfigs) {
        for (RateLimitConfig limitConfig : limitConfigs) {
            String key = limitConfig.getKey();
            RateLimiterType type = limitConfig.getType();
            Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
            String methodKey = prefix + parseElKey(joinPoint, limitConfig.getKey());

            // Based on client IP
            if (type == RateLimiterType.CLIENT_IP) {
                HttpServletRequest request = Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                        .map(ServletRequestAttributes.class::cast)
                        .map(ServletRequestAttributes::getRequest)
                        .orElseThrow(() -> new IllegalStateException("The Request object is only available in a web environment!"));
                String clientIP = ServletUtil.getClientIP(request);
                methodKey = methodKey + ":" + clientIP;
            }
            // Based on user
            else if (type == RateLimiterType.USER) {
                String userId = authentication.getPrincipal().toString();
                methodKey = methodKey + ":" + userId;
            }
            // Custom: the key must not be empty
            else {
                Assert.hasText(key, "The rate limit key must not be empty!");
            }

            limitConfig.setKey(methodKey);
        }
    }

    private boolean parsePostSpEl(Object val, RateLimitConfig limitConfig) {
        String condition = limitConfig.getCondition();
        if (StringUtils.isBlank(condition) || !condition.contains(EL_PREFIX)) {
            return true;
        }
        StandardEvaluationContext context = new StandardEvaluationContext();
        context.setVariable("rtv", val);
        Expression expression = expressionParser.parseExpression(condition);
        return Optional.ofNullable(expression.getValue(context, Boolean.class)).orElse(true);
    }
}

The code above uses Redis pipelining and a Redis lock. The logic is not complicated, so I won't go through it in detail; questions are welcome!
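
For readers who want to step through the counting logic without Redis, the Lua script in the aspect can be approximated in plain Java as below. This is a sketch under assumptions: FixedWindowSketch is a hypothetical class, key expiry is emulated with a second map, and the current time is passed in as a parameter so the behaviour is deterministic.

```java
import java.util.HashMap;
import java.util.Map;

// In-memory approximation of the Lua script: GET, compare, INCR, PEXPIRE on first hit.
class FixedWindowSketch {
    private final Map<String, Long> counts = new HashMap<>();
    private final Map<String, Long> expiresAt = new HashMap<>();

    // Returns 0 when the caller is limited and 1 when the call was allowed and counted.
    synchronized int tryAcquire(String key, int limit, long windowMillis, long nowMillis) {
        Long deadline = expiresAt.get(key);
        if (deadline != null && nowMillis >= deadline) {
            // Emulate Redis dropping the key once its TTL has elapsed
            counts.remove(key);
            expiresAt.remove(key);
        }
        long current = counts.getOrDefault(key, 0L);      // GET key, defaulting to 0
        if (current >= limit) {
            return 0;
        }
        current = counts.merge(key, 1L, Long::sum);       // INCR key
        if (current == 1) {
            expiresAt.put(key, nowMillis + windowMillis); // PEXPIRE only on the first hit
        }
        return 1;
    }
}
```

This is a fixed-window counter: the window starts with the first counted call and is not extended by later calls. A known property of fixed windows is that a burst at the end of one window followed by a burst at the start of the next can briefly let through close to twice the limit.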

Parent class code

public abstract class AbstractAspect {
    protected static ExpressionParser expressionParser = new SpelExpressionParser();
    protected static final String EL_PREFIX = "#";

    protected String getMethodKey(ProceedingJoinPoint joinPoint, String elKey) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        String limitKey = parseElKey(joinPoint, elKey);
        String className = signature.getDeclaringType().getSimpleName();
        String methodName = signature.getName();
        limitKey = "method:" + className + "#" + methodName + "#" + limitKey;

        return limitKey;
    }

    protected String parseElKey(ProceedingJoinPoint joinPoint, String elKey) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        String[] parameterNames = signature.getParameterNames();
        Object[] parameterValues = joinPoint.getArgs();

        // Evaluate the EL expression
        if (elKey.contains(EL_PREFIX)) {
            StandardEvaluationContext context = new StandardEvaluationContext();
            for (int i = 0; i < parameterNames.length; i++) {
                context.setVariable(parameterNames[i], parameterValues[i]);
            }
            // Parse and evaluate
            Expression expression = expressionParser.parseExpression(elKey);
            return expression.getValue(context, String.class);
        }

        return elKey;
    }

    protected Map<String, Object> getMethodParameters(JoinPoint joinPoint) {
        Map<String, Object> parameters = new LinkedHashMap<>(18);
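        // The signature, not the join point itself, is the MethodSignature
        if (joinPoint.getSignature() instanceof MethodSignature) {
            // Argument values
            Object[] argValues = joinPoint.getArgs();
            if(argValues != null) {
                // Argument names
                String[] argNames = ((MethodSignature) joinPoint.getSignature()).getParameterNames();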
                for (int i = 0; i < argNames.length; i++) {
                    parameters.put(argNames[i], argValues[i]);
                }
            }
        }

        return parameters;
    }
}

Usage

@KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
Pre-invocation limiting by client IP

@KeyRateLimiter(type = RateLimiterType.USER)
Pre-invocation limiting by user

@PostKeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username", interval = 60, condition = "#rtv.code == 200")
Limiting by a custom key; a call counts towards the limit only when the return value has code == 200

@Slf4j
@RestController
@Api(tags = "System: authorization API")
@RequiredArgsConstructor
public class AuthController {

    @AnonymousAccess
    @ApiOperation("Get verification code")
    @GetMapping(value = "/code")
    @KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
    @KeyRateLimiter(type = RateLimiterType.USER)
    @PostKeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username", interval = 60, condition = "#rtv.code == 200")
    public XCloudResponse<Object> getCode(@RequestParam String username) {
        // Implementation details omitted
    }
}