SpringBoot 基于 bucket4j 框架实现限流

SpringBoot 使用 Google 缓存框架 + bucket4j 限流框架实现的基于内存的单机限流,分布式限流可以在网关层使用。

(1)引入依赖

    <dependencies>

        <!-- Google 缓存框架 -->
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>31.1-jre</version>
        </dependency>

        <!--  使用 SpringBoot Aop 框架 -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>

        <!-- bucket4j 限流框架 -->
        <dependency>
            <groupId>com.github.vladimir-bukhtoyarov</groupId>
            <artifactId>bucket4j-core</artifactId>
            <version>7.6.0</version>
        </dependency>

        <!-- 注入常用工具类 -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.12.0</version>
        </dependency>

        <dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-all</artifactId>
            <version>5.8.18</version>
        </dependency>

    </dependencies>

(2)定义 @Throttling 和 @Throttlings 接口限流注解

/**
 * @author chenyusheng
 * @create 2023/6/28 9:56
 * @description 接口限流注解
 * 注:可以用在方法和类上,可以同时存在多个
 */
@Inherited
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Throttling {

    /**
     * 生成的令牌数量
     */
    int capacity() default 5;

    /**
     * 每单位时间,生成的令牌数量,默认:1 秒内生成 5 个令牌
     */
    long period() default 1;

    /**
     * 时间单位,默认:秒
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;

    /**
     * 唯一标识,需要实现 ThrottlingKeyProvider 接口
     */
    Class<? extends ThrottlingKeyProvider> keyProvider();

    /**
     * 缓存 key 的前缀
     * (1)默认使用方法的签名的 MD5,作为缓存 key 的前缀,用于区分是为了给哪个方法设置限流;
     * (2)如果需要要对全局进行限流,需设置固定值,例如想实现全局IP的限流,可以采用固定 prefix + (IP)keyProvider 的模式,相当于提供某一组接口共用一个限流配置
     */
    String prefix() default "";
}
/**
 * @author chenyusheng
 * @create 2023/6/28 9:56
 * @description 接口限流注解
 */
@Inherited
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Throttlings {
    Throttling[] value();
}
/**
 * @author chenyusheng
 * @create 2023/6/28 9:56
 * @description 限流唯一 key 生成接口
 */
public interface ThrottlingKeyProvider {

    /**
     * 生成限流标识
     *
     * @return
     */
    String generateUniqKey();
}

(3)定义 ThrottlingAop 限流 AOP

/**
 * @author chenyusheng
 * @create 2023/6/28 9:56
 * @description AOP切面 抽象基类
 */
public abstract class AbstractAop {

    /**
     * 获取 Method 的指定注解
     *
     * @param joinPoint
     * @return T extends Annotation
     */
    protected <T extends Annotation> T getMethodAnnotation(ProceedingJoinPoint joinPoint, Class<T> clazz) {
        T annotation = null;
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Class<?>[] paramTypeArray = methodSignature.getParameterTypes();
        Method transferMoney = null;
        try {
            transferMoney = joinPoint.getTarget().getClass().getDeclaredMethod(methodSignature.getName(), paramTypeArray);
        } catch (NoSuchMethodException e) {
            e.printStackTrace();
        }
        assert transferMoney != null;
        boolean annotationPresent = transferMoney.isAnnotationPresent(clazz);
        if (annotationPresent) {
            annotation = transferMoney.getAnnotation(clazz);
        }
        return annotation;
    }

    /**
     * 获取当前类及其父类包含的指定的注解类型
     *
     * @param clazz
     * @param annotation
     * @param <T>
     * @return
     */
    protected <T extends Annotation> List<T> getInheritClassAnnotation(Class<?> clazz, Class<T> annotation) {
        List<T> list = new ArrayList<>();
        // 获取当前类及其父类包含的指定的T类型注解
        do {
            T annot = getClassAnnotation(clazz, annotation);
            if (annot != null) {
                list.add(annot);
            }
            clazz = clazz.getSuperclass();
        } while (clazz != null);
        return list;
    }

    /**
     * 获取指定类的注解
     *
     * @param clazz
     * @param annotation
     * @param <T>
     * @return
     */
    protected <T extends Annotation> T getClassAnnotation(Class<?> clazz, Class<T> annotation) {
        T res = null;
        boolean annotationPresent = clazz.isAnnotationPresent(annotation);
        if (annotationPresent) {
            res = clazz.getAnnotation(annotation);
        }
        return res;
    }
}
/**
 * @author chenyusheng
 * @create 2023/6/28 9:56
 * @description 限流 AOP
 * 注意:基于内存保存数据,缓存的数据量不大,且仅仅支持单体应用限流.
 */
@Aspect
@Component
@Order(1)
public class ThrottlingAop extends AbstractAop {

    /**
     * 基于内存保存数据,缓存的数据量不大,且仅仅支持单体应用.
     */
    public static Cache<String, Bucket> localCache = CacheBuilder.newBuilder()
            // 指定缓存项在给定时间内没有被访问,则回收
            .expireAfterAccess(1, TimeUnit.DAYS)
            // 设置缓存容器的初始容量
            .initialCapacity(200)
            // 设置缓存最大容量,超过之后就会按照LRU最近最少使用算法来移除缓存项
            .maximumSize(10000)
            // 设置并发级别为8,并发级别是指可以同时写缓存的线程数
            .concurrencyLevel(8).build();

    @Resource
    private ApplicationContext applicationContext;

    /**
     * 基于方法上的注解
     */
    @Pointcut("@annotation(cn.yifants.common.extension.aop.throttling.Throttling) || @annotation(cn.yifants.common.extension.aop.throttling.Throttlings)")
    public void methodAnnotation() {
    }

    /**
     * 基于类上的注解
     */
    @Pointcut("@within(cn.yifants.common.extension.aop.throttling.Throttling) || @within(cn.yifants.common.extension.aop.throttling.Throttlings)")
    public void classAnnotation() {
    }

    @Around("methodAnnotation() || classAnnotation()")
    public Object rounding(ProceedingJoinPoint joinPoint) throws Throwable {

        List<Throttling> annotations = new ArrayList<>();

        // 解析 Throttling 注解
        Throttling throttling = getMethodAnnotation(joinPoint, Throttling.class);
        if (throttling != null) {
            annotations.add(throttling);
        }
        annotations.addAll(getInheritClassAnnotation(joinPoint.getTarget().getClass(), Throttling.class));

        // 解析 Throttlings 注解
        Throttlings throttlings = getMethodAnnotation(joinPoint, Throttlings.class);
        if (throttlings != null) {
            annotations.addAll(Arrays.asList(throttlings.value()));
        }
        getInheritClassAnnotation(joinPoint.getTarget().getClass(), Throttlings.class).forEach(x -> {
            annotations.addAll(Arrays.asList(x.value()));
        });

        // 限流处理
        if (CollectionUtil.isNotEmpty(annotations)) {
            // 没有指定 prefix 时,默认使用方法的签名的 MD5,作为缓存 key 的前缀,用于区分是为了给哪个方法设置限流
            String signature = SecureUtil.md5(joinPoint.getSignature().toString());
            annotations.stream().distinct().collect(Collectors.groupingBy(x -> ThrottlingKey.builder().keyProvider(x.keyProvider()).prefix(x.prefix()).build())).forEach((key, val) -> {
                if (StringUtils.isBlank(key.getPrefix())) {
                    key.setPrefix(signature);
                }
                Throttlinghandle(key, val);
            });
        }

        Object[] args = joinPoint.getArgs();

        return joinPoint.proceed(args);
    }

    /**
     * 限流处理
     *
     * @param keyProvider
     * @param list
     */
    private synchronized void Throttlinghandle(ThrottlingKey keyProvider, List<Throttling> list) {

        ThrottlingKeyProvider throttlingKeyProvider = applicationContext.getBean(keyProvider.getKeyProvider(), ThrottlingKeyProvider.class);

        // 限流唯一 key
        String uniqKey = keyProvider.getPrefix() + ":" + throttlingKeyProvider.generateUniqKey();

        Bucket bucket = localCache.getIfPresent(uniqKey);

        if (bucket == null) {
            bucket = createNewBucket(list);
            localCache.put(uniqKey, bucket);
        }

        // 尝试获取令牌
        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);

        // 获取不到令牌,直接抛出限流异常
        if (!probe.isConsumed()) {
            throw new ServiceException(StatusCodeEnum.TOO_MANY_REQUESTS, "Too many requests,X-Rate-Limit-Retry-After-Seconds:" + TimeUnit.NANOSECONDS.toSeconds(probe.getNanosToWaitForRefill()) + ",keyProvider:" + keyProvider);
        }

    }

    /**
     * 创建一个限流规则
     *
     * @return
     */
    private Bucket createNewBucket(List<Throttling> list) {

        LocalBucketBuilder builder = Bucket.builder();

        list.forEach(x -> {
            // 每 period 单位时间内最高调用 capacity 次数
            switch (x.timeUnit()) {
                case SECONDS:
                    builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofSeconds(x.period()))));
                    break;
                case MINUTES:
                    builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofMinutes(x.period()))));
                    break;
                case HOURS:
                    builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofHours(x.period()))));
                    break;
                case DAYS:
                    builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofDays(x.period()))));
                    break;
                default:
                    throw new IllegalStateException("Unexpected value: " + x.timeUnit());
            }
        });

        return builder.build();
    }

    @Data
    @Builder
    @AllArgsConstructor
    @NoArgsConstructor
    public static class ThrottlingKey {

        /**
         * 唯一标识,需要实现 ThrottlingKeyProvider 接口
         */
        private Class<? extends ThrottlingKeyProvider> keyProvider;

        /**
         * 缓存 key 的前缀
         * (1)默认使用方法的签名的 MD5,作为缓存 key 的前缀,用于区分是为了给哪个方法设置限流;
         * (2)如果需要要对全局进行限流,需设置固定值,例如想实现全局IP的限流,可以采用固定 prefix + (IP)keyProvider 的模式
         */
        private String prefix;
    }
}

(4)使用方式

  • 实现 ThrottlingKeyProvider 接口,生成限流唯一 key;
/**
 * @author chenyusheng
 * @create 2023/6/28 9:56
 * @description 自定义限流 实现类
 */
@Slf4j
@Component("ipThrottlingKeyProviderImpl")
public class IpThrottlingKeyProviderImpl implements ThrottlingKeyProvider {

    /**
     * HttpServletRequest 请求上下文
     */
    @Resource
    protected HttpServletRequest httpServletRequest;

    /**
     * 生成限流标识
     *
     * @return
     */
    @Override
    public String generateUniqKey() {
        // 获取当前请求的 IP
        String ip = IpUtil.getIpAddress(httpServletRequest);
        if (StringUtils.isBlank(ip)) {
            throw new ServiceException(StatusCodeEnum.SERVICE_UNAVAILABLE, "无法获取当前请求的IP");
        }
        return ip;
    }
}
  • 使用注解 @Throttling(capacity = 2, keyProvider = IpThrottlingKeyProviderImpl.class)
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
Spring Boot 中的 AOP (Aspect Oriented Programming) 可以用来实现接口级别的限流功能,通常我们会结合第三方库如 Spring Cloud Gateway、Webrx 或者自定义拦截器来完成这个任务。以下是一个简单的概述: 1. 引入依赖:首先,在Spring Boot项目中添加限流库的依赖,比如 Spring Cloud Gateway 提供了 WebFlux 基于令牌桶(Token Bucket)的限流支持。 ```xml <dependency> <groupId>org.springframework.cloud</groupId> <artifactId>spring-cloud-starter-gateway</artifactId> </dependency> ``` 2. 定义限流规则:在配置类中设置限流策略,例如限制某个接口每秒的请求次数。Spring Cloud Gateway 使用 `RateLimiter` 来控制。 ```java @Bean public RateLimiter myLimit(RateLimiterConfig config) { // 设置限流参数,如每秒50次请求 return RateLimiter.of("my-limit", config.limitForPeriod(1, TimeUnit.SECONDS), false); } ``` 3. AOP 配置:创建一个切面(Aspect),利用 `@Around` 注解和 `RateLimiter` 对目标方法进行拦截,并在调用之前检查是否达到限流阈值。 ```java @Aspect @Component public class ApiRateLimitingAspect { @Autowired private RateLimiter myLimit; @Around("@annotation(api)") public Object limitApi(ProceedingJoinPoint joinPoint, Api api) throws Throwable { if (!myLimit.tryAcquire()) { throw new RateLimiterRejectedException("Exceeded rate limit"); } // 执行原方法 return joinPoint.proceed(); } // 如果你需要为每个接口定义不同的限流规则,可以使用注解来标记 @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) public @interface Api { String value() default ""; } } ``` 在这个例子中,我们假设有一个 `Api` 注解用于标记接口,然后在 `limitApi` 方法中对被该注解修饰的方法进行限流
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值