Redis+LUA脚本结合AOP实现限流

1、demo结构

在这里插入图片描述

2、自定义接口

通过自定义接口标注需要限流的接口

/**
 * redis限流自定义注解
 * @author zyw
 */
//注解的保留位置,RUNTIME表示这种类型的Annotations将被JVM保留,所以他们能在运行时被JVM或其他使用反射机制的代码所读取和使用。
@Retention(RetentionPolicy.RUNTIME)
//说明注解的作用目标,METHOD表示用来修饰方法
@Target({ElementType.METHOD})
//说明该注解将被包含在javadoc中
@Documented
public @interface RedisLimit {
    /**
     * 资源的key,唯一
     * 作用:不同的接口,不同的流量控制
     */
    String key() default "";

    /**
     * 最多的访问限制次数
     */
    long permitsPerSecond() default 2;

    /**
     * 过期时间也可以理解为单位时间,单位秒,默认60
     */
    long expire() default 60;


    /**
     * 得不到令牌的提示语
     */
    String msg() default "系统繁忙,请稍后再试.";
}

3、编写写LUA脚本

通过Lua脚本动态实现动态的创建redis缓存

--获取KEY
local key = KEYS[1]

local limit = tonumber(ARGV[1])

local curentLimit = tonumber(redis.call('get', key) or "0")

if curentLimit + 1 > limit
then return 0
else
    -- 自增长 1
    redis.call('INCRBY', key, 1)
    -- 设置过期时间
    redis.call('EXPIRE', key, ARGV[2])
    return curentLimit + 1
end

4、通过AOP切面识别需要限流的接口

编写切面

  • 1 定义一个类,该类添加了@Component、@Aspect注解
  • 2 定义切点(切点定义方式可参考《Spring AOP配置 之 @PointCut注解》)
  • 3 配置增强,给方法添加@Before、@After、@AfterReturning、@AfterThrowing、@Around等增强配置。

AOP通知类型

  • @Around 环绕通知
  • @Before 通知执行
  • @Before 通知执行结束
  • @Around 环绕通知执行结束
  • @After 后置通知执行了
  • @AfterReturning 第一个后置返回通知后执行
/**
 * Limit AOP
 */
@Slf4j
@Aspect
@Component
public class RedisLimitAop {

    @Autowired
    private StringRedisTemplate stringRedisTemplate;


    @Pointcut("@annotation(com.example.redislimit.aop.RedisLimit)")
    private void check() {

    }

    private DefaultRedisScript<Long> redisScript;

    @PostConstruct
    public void init() {
        redisScript = new DefaultRedisScript<>();
        redisScript.setResultType(Long.class);
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("rateLimiter.lua")));
    }


    @Before("check()")
    public void before(JoinPoint joinPoint) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        // 请求对象
        ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest servletRequest = sra.getRequest();
        //拿到RedisLimit注解,如果存在则说明需要限流
        RedisLimit redisLimit = method.getAnnotation(RedisLimit.class);

        if (redisLimit != null) {
            //获取redis的key
            String key = redisLimit.key();
            String className = method.getDeclaringClass().getName();
            String name = method.getName();

            String limitKey = key + className + method.getName();

            log.info(limitKey);

            if (StringUtils.isEmpty(key)) {
                throw new RedisLimitException("key cannot be null");
            }

            long limit = redisLimit.permitsPerSecond();

            long expire = redisLimit.expire();

            List<String> keys = new ArrayList<>();
            keys.add(key);

            Long count = stringRedisTemplate.execute(redisScript, keys, String.valueOf(limit), String.valueOf(expire));

            log.info("Access try count is {} for key={}", count, key);

            if (count != null && count == 0) {
                log.debug("获取key失败,key为{}", key);
                throw new RedisLimitException(redisLimit.msg());
            }
        }

    }

5、Redis限流自定义异常构建

Redis限流自定义异常
/**
 * Redis限流自定义异常
 * @date 2023/3/10 21:43
 */
public class RedisLimitException extends RuntimeException{
    public RedisLimitException(String msg) {
        super( msg );
    }
}

声明这个类为全局异常处理器
@RestControllerAdvice// 声明这个类为全局异常处理器
public class GlobalExceptionHandler {


    @ExceptionHandler(RedisLimitException.class) // 声明当前方法要处理的异常类型
    public ResultInfo handlerCustomException(RedisLimitException e) {
        //1. 打印日志
//        e.printStackTrace();

        //2. 给前端提示
        return ResultInfo.error(e.getMessage());
    }


    //非预期异常 对于他们,我们直接捕获,捕获完了,记录日志, 给前端一个假提示
    @ExceptionHandler(Exception.class)
    public ResultInfo handlerException(Exception e) {
        //1. 打印日志
        e.printStackTrace();

        //2. 给前端提示
        return ResultInfo.error("当前系统异常");
    }
}
专属日志
@Getter
@Setter
public class ResultInfo<T> {

    private String message;
    private String code;
    private T data;


    public ResultInfo(String message, String code, T data) {
        this.message = message;
        this.code = code;
        this.data = data;
    }

    public static ResultInfo error(String message) {
        return new ResultInfo(message,"502",null);
    }

}

6、流量限制器

RateLimiter
public class RateLimiter {
    private static final Logger log = LoggerFactory.getLogger(RateLimiter.class);
    // 为每个api在内存中存储限流计数器
    private ConcurrentHashMap<String, RateLimitAlg> counters = new ConcurrentHashMap<>();
    private RateLimitRule rule;
    public RateLimiter() {
        // 将限流规则配置文件ratelimiter-rule.yaml中的内容读取到RuleConfig中
        InputStream in = null;
        RuleConfig ruleConfig = null;
        try {
            in = this.getClass().getResourceAsStream("/ratelimiter-rule.yaml");
            if (in != null) {
                Yaml yaml = new Yaml();
                ruleConfig = yaml.loadAs(in, RuleConfig.class);
            }
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException e) {
                    log.error("close file error:{}", e);
                }
            }
        }
        // 将限流规则构建成支持快速查找的数据结构RateLimitRule
        this.rule = new RateLimitRule(ruleConfig);
    }
    public boolean limit(String appId, String url) throws Exception {
        ApiLimit apiLimit = rule.getLimit(appId, url);
        if (apiLimit == null) {
            return true;
        }
        // 获取api对应在内存中的限流计数器(rateLimitCounter)
        String counterKey = appId + ":" + apiLimit.getApi();
        RateLimitAlg rateLimitCounter = counters.get(counterKey);
        if (rateLimitCounter == null) {
            RateLimitAlg newRateLimitCounter = new RateLimitAlg(apiLimit.getLimit());
            rateLimitCounter = counters.putIfAbsent(counterKey, newRateLimitCounter);
            if (rateLimitCounter == null) {
                rateLimitCounter = newRateLimitCounter;
            }
        }
        // 判断是否限流
        return rateLimitCounter.tryAcquire();
    }
}
RateLimitAlg
public class RateLimitAlg {
    /* timeout for {@code Lock.tryLock() }. */
    private static final long TRY_LOCK_TIMEOUT = 200L;  // 200ms.
    private Stopwatch stopwatch;
    private AtomicInteger currentCount = new AtomicInteger(0);
    private final int limit;
    private Lock lock = new ReentrantLock();
    public RateLimitAlg(int limit) {
        this(limit, Stopwatch.createStarted());
    }
    @VisibleForTesting
    protected RateLimitAlg(int limit, Stopwatch stopwatch) {
        this.limit = limit;
        this.stopwatch = stopwatch;
    }
    public boolean tryAcquire() throws Exception {
        int updatedCount = currentCount.incrementAndGet();
        if (updatedCount <= limit) {
            return true;
        }
        try {
            if (lock.tryLock(TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS)) {
                try {
                    if (stopwatch.elapsed(TimeUnit.MILLISECONDS) > TimeUnit.SECONDS.toMillis(1)) {
                        currentCount.set(0);
                        stopwatch.reset();
                    }
                    updatedCount = currentCount.incrementAndGet();
                    return updatedCount <= limit;
                } finally {
                    lock.unlock();
                }
            } else {
                throw new Exception("tryAcquire() wait lock too long:" + TRY_LOCK_TIMEOUT + "ms");
            }
        } catch (InterruptedException e) {
            throw new Exception("tryAcquire() is interrupted by lock-time-out.", e);
        }
    }
}
ApiLimit
public class ApiLimit {
    private static final int DEFAULT_TIME_UNIT = 1; // 1 second
    private String api;
    private int limit;
    private int unit = DEFAULT_TIME_UNIT;
    public ApiLimit() {}
    public ApiLimit(String api, int limit) {
        this(api, limit, DEFAULT_TIME_UNIT);
    }
    public ApiLimit(String api, int limit, int unit) {
        this.api = api;
        this.limit = limit;
        this.unit = unit;
    }

    public String getApi() {
        return api;
    }

    public void setApi(String api) {
        this.api = api;
    }

    public int getLimit() {
        return limit;
    }

    public void setLimit(int limit) {
        this.limit = limit;
    }

    public int getUnit() {
        return unit;
    }

    public void setUnit(int unit) {
        this.unit = unit;
    }
}
RateLimitRule
public class RateLimitRule {
    public RateLimitRule(RuleConfig ruleConfig) {
        //...
    }
    public ApiLimit getLimit(String appId, String api) {
        return null;
    }
}
RuleConfig
public class RuleConfig {
    private List<AppRuleConfig> configs;
    public List<AppRuleConfig> getConfigs() {
        return configs;
    }
    public void setConfigs(List<AppRuleConfig> configs) {
        this.configs = configs;
    }
    public static class AppRuleConfig {
        private String appId;
        private List<ApiLimit> limits;
        public AppRuleConfig() {}
        public AppRuleConfig(String appId, List<ApiLimit> limits) {
            this.appId = appId;
            this.limits = limits;
        }

        public String getAppId() {
            return appId;
        }

        public void setAppId(String appId) {
            this.appId = appId;
        }

        public List<ApiLimit> getLimits() {
            return limits;
        }

        public void setLimits(List<ApiLimit> limits) {
            this.limits = limits;
        }
    }
}

7、Guavalimit

限流自定义异常
/**
 * @Description 限流自定义异常
 * @Author zyw
 * @Date 2019/8/7 16:01
 */
public class LimitAccessException extends RuntimeException {

    private static final long serialVersionUID = -3608667856397125671L;

    public LimitAccessException(String message) {
        super(message);
    }
}
限流key类型枚举
/**
 * @Description 限流key类型枚举
 * @Author zyw
 * @Date 2020/5/17 14:28
 */
public enum LimitKeyTypeEnum {

    IPADDR("IPADDR", "根据Ip地址来限制"),
    CUSTOM("CUSTOM", "自定义根据业务唯一码来限制,需要在请求参数中添加 String limitKeyValue");

    private String keyType;
    private String desc;

    LimitKeyTypeEnum(String keyType, String desc) {
        this.keyType = keyType;
        this.desc = desc;
    }

    public String getKeyType() {
        return keyType;
    }

    public String getDesc() {
        return desc;
    }
}

自定义限流注解

/**
 * @Description 自定义限流注解
 * @Author zyw
 * @Date 2020/5/17 11:49
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface LxRateLimit {

    //资源名称
    String name() default "默认资源";

    //限制每秒访问次数,默认为3次
    double perSecond() default 3;

    /**
     * 限流Key类型
     * 自定义根据业务唯一码来限制需要在请求参数中添加 String limitKeyValue
     */
    LimitKeyTypeEnum limitKeyType() default LimitKeyTypeEnum.IPADDR;

}
基于Guava cache缓存存储实现限流切面
/**
 * @Description 基于Guava cache缓存存储实现限流切面
 * @Author 张佑威
 * @Date 2020/5/17 11:51
 */
@Slf4j
@Aspect
@Component
public class LxRateLimitAspect {

    /**
     * 缓存
     * maximumSize 设置缓存个数
     * expireAfterWrite 写入后过期时间
     */
    private static LoadingCache<String, RateLimiter> limitCaches = CacheBuilder.newBuilder()
            .maximumSize(1000)
            .expireAfterWrite(1, TimeUnit.DAYS)
            .build(new CacheLoader<String, RateLimiter>() {
                @Override
                public RateLimiter load(String key) throws Exception {
                    double perSecond = LxRateLimitUtil.getCacheKeyPerSecond(key);
                    return RateLimiter.create(perSecond);
                }
            });

    /**
     * 切点
     * 通过扫包切入 @Pointcut("execution(public * com.ycn.springcloud.*.*(..))")
     * 带有指定注解切入 @Pointcut("@annotation(com.ycn.springcloud.annotation.LxRateLimit)")
     */
    @Pointcut("@annotation(com.example.guavalimit.limit.LxRateLimit)")
    public void pointcut() {
    }

    @Around("pointcut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        log.info("限流拦截到了{}方法...", point.getSignature().getName());
        HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        if (method.isAnnotationPresent(LxRateLimit.class)) {
            String cacheKey = LxRateLimitUtil.generateCacheKey(method, request);
            RateLimiter limiter = limitCaches.get(cacheKey);
            if (!limiter.tryAcquire()) {
                throw new LimitAccessException("【限流】这位小同志的手速太快了");
            }
        }
        return point.proceed();
    }
}

限流工具类

/**
 * @Description 限流工具类
 * @Author zyw
 * @Date 2020/5/17 15:37
 */
public class LxRateLimitUtil {


    /**
     * 获取唯一key根据注解类型
     * <p>
     * 规则 资源名:业务key:perSecond
     *
     * @param method
     * @param request
     * @return
     */
    public static String generateCacheKey(Method method, HttpServletRequest request) {
        //获取方法上的注解
        LxRateLimit lxRateLimit = method.getAnnotation(LxRateLimit.class);
        StringBuffer cacheKey = new StringBuffer(lxRateLimit.name() + ":");
        switch (lxRateLimit.limitKeyType()) {
            case IPADDR:
                cacheKey.append(getIpAddr(request) + ":");
                break;
            case CUSTOM:
                String limitKeyValue = request.getParameter("limitKeyValue");
                if (StringUtils.isEmpty(limitKeyValue)) {
                    throw new LimitAccessException("【" + method.getName() + "】自定义业务Key缺少参数String limitKeyValue,或者参数为空");
                }
                cacheKey.append(limitKeyValue + ":");
                break;
        }
        cacheKey.append(lxRateLimit.perSecond());
        return cacheKey.toString();
    }

    /**
     * 获取缓存key的限制每秒访问次数
     * <p>
     * 规则 资源名:业务key:perSecond
     *
     * @param cacheKey
     * @return
     */
    public static double getCacheKeyPerSecond(String cacheKey) {
        String perSecond = cacheKey.split(":")[2];
        return Double.parseDouble(perSecond);
    }

    /**
     * 获取客户端IP地址
     *
     * @param request 请求
     * @return
     */
    public static String getIpAddr(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if ("0:0:0:0:0:0:0:1".equals(ip)) {
            ip = "127.0.0.1";
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
            if ("127.0.0.1".equals(ip)) {
                //根据网卡取本机配置的IP
                InetAddress inet = null;
                try {
                    inet = InetAddress.getLocalHost();
                } catch (UnknownHostException e) {
                    e.printStackTrace();
                }
                ip = inet.getHostAddress();
            }
        }
        // 对于通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照','分割
        if (ip != null && ip.length() > 15) {
            if (ip.indexOf(",") > 0) {
                ip = ip.substring(0, ip.indexOf(","));
            }
        }
        return ip;
    }
}

8、测试控制层

@RestController
public class TestController {

    @GetMapping("/test")
    public String getTest(){
        return "jxj";
    }

    @GetMapping("/guavalimit")
    @LxRateLimit
    public String guavaLimit(){
        return "ok";
    }

    @GetMapping("/redislimit")
    @RedisLimit(key = "redis-limit:test", permitsPerSecond = 2, expire = 1, msg = "当前排队人数较多,请稍后再试!")
    public String redisLimit(){
        return "ok";
    }
}

9、测试结果

Redis+LUA脚本

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
实现分布式限流可以使用 RedisLua 脚本来完成。以下是可能的实现方案: 1. 使用 Redis 的 SETNX 命令来实现基于令牌桶算法的限流 令牌桶算法是一种常见的限流算法,它可以通过令牌的放置和消耗来控制流量。在 Redis 中,我们可以使用 SETNX 命令来实现令牌桶算法。 具体实现步骤如下: - 在 Redis 中创建一个有序集合,用于存储令牌桶的令牌数量和时间戳。 - 每当一个请求到达时,我们首先获取当前令牌桶中的令牌数量和时间戳。 - 如果当前时间戳与最后一次请求的时间戳之差大于等于令牌桶中每个令牌的发放时间间隔,则将当前时间戳更新为最后一次请求的时间戳,并且将令牌桶中的令牌数量增加相应的数量,同时不超过最大容量。 - 如果当前令牌桶中的令牌数量大于等于请求需要的令牌数量,则返回 true 表示通过限流,将令牌桶中的令牌数量减去请求需要的令牌数量。 - 如果令牌桶中的令牌数量不足,则返回 false 表示未通过限流。 下面是使用 RedisLua 脚本实现令牌桶算法的示例代码: ```lua -- 限流的 key local key = KEYS[1] -- 令牌桶的容量 local capacity = tonumber(ARGV[1]) -- 令牌的发放速率 local rate = tonumber(ARGV[2]) -- 请求需要的令牌数量 local tokens = tonumber(ARGV[3]) -- 当前时间戳 local now = redis.call('TIME')[1] -- 获取当前令牌桶中的令牌数量和时间戳 local bucket = redis.call('ZREVRANGEBYSCORE', key, now, 0, 'WITHSCORES', 'LIMIT', 0, 1) -- 如果令牌桶为空,则初始化令牌桶 if not bucket[1] then redis.call('ZADD', key, now, capacity - tokens) return 1 end -- 计算当前令牌桶中的令牌数量和时间戳 local last = tonumber(bucket[2]) local tokensInBucket = tonumber(bucket[1]) -- 计算时间间隔和新的令牌数量 local timePassed = now - last local newTokens = math.floor(timePassed * rate) -- 更新令牌桶 if newTokens > 0 then tokensInBucket = math.min(tokensInBucket + newTokens, capacity) redis.call('ZADD', key, now, tokensInBucket) end -- 检查令牌数量是否足够 if tokensInBucket >= tokens then redis.call('ZREM', key, bucket[1]) return 1 else return 0 end ``` 2. 使用 RedisLua 脚本实现基于漏桶算法的限流 漏桶算法是另一种常见的限流算法,它可以通过漏桶的容量和漏水速度来控制流量。在 Redis 中,我们可以使用 Lua 脚本实现漏桶算法。 具体实现步骤如下: - 在 Redis 中创建一个键值对,用于存储漏桶的容量和最后一次请求的时间戳。 - 每当一个请求到达时,我们首先获取当前漏桶的容量和最后一次请求的时间戳。 - 计算漏水速度和漏水的数量,将漏桶中的容量减去漏水的数量。 - 如果漏桶中的容量大于等于请求需要的容量,则返回 true 表示通过限流,将漏桶中的容量减去请求需要的容量。 - 如果漏桶中的容量不足,则返回 false 表示未通过限流。 下面是使用 RedisLua 脚本实现漏桶算法的示例代码: ```lua -- 限流的 key local key = KEYS[1] -- 漏桶的容量 local capacity = tonumber(ARGV[1]) -- 漏水速度 local rate = tonumber(ARGV[2]) -- 请求需要的容量 local size = tonumber(ARGV[3]) -- 当前时间戳 local now = redis.call('TIME')[1] -- 获取漏桶中的容量和最后一次请求的时间戳 local bucket = redis.call('HMGET', key, 'capacity', 'last') -- 如果漏桶为空,则初始化漏桶 if not bucket[1] then redis.call('HMSET', key, 'capacity', capacity, 'last', now) return 1 end -- 计算漏水的数量和漏桶中的容量 local last = tonumber(bucket[2]) local capacityInBucket = tonumber(bucket[1]) local leak = math.floor((now - last) * rate) -- 更新漏桶 capacityInBucket = math.min(capacity, capacityInBucket + leak) redis.call('HSET', key, 'capacity', capacityInBucket) redis.call('HSET', key, 'last', now) -- 检查容量是否足够 if capacityInBucket >= size then return 1 else return 0 end ``` 以上是使用 RedisLua 脚本实现分布式限流的两种方案,可以根据实际需求选择适合的方案。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

柚几哥哥

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值