转载自 https://gitee.com/xkcoding/spring-boot-demo?_from=gitee_search 中的 demo-ratelimit-guava 模块
本文介绍 Spring Boot 项目如何通过 AOP 结合 Guava 的 RateLimiter 实现限流,旨在防止 API 被恶意频繁访问。
pom 文件中的相关依赖如下:
<!-- NOTE(review): no version element is declared for any dependency below; presumably the versions are managed by a parent POM or dependencyManagement section. Confirm before copying. -->
<!-- Spring AOP starter: needed for the @Aspect / @Around based rate-limiting aspect. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Hutool: the handler and controller snippets return its Dict type as the response body. -->
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
</dependency>
<!-- Guava: supplies com.google.common.util.concurrent.RateLimiter, used by the aspect. -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
1.先定义一个限流注解
/**
 * Marks a method as rate limited.
 *
 * <p>{@link #value()} and {@link #qps()} are aliases of each other; either one
 * may be used to declare the allowed rate. The rate-limiting aspect only
 * enforces a limit when the declared rate is greater than
 * {@link #NOT_LIMITED}.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /** Sentinel rate meaning "no limit"; it is also the default for both rate attributes. */
    int NOT_LIMITED = 0;

    /**
     * Allowed queries per second. Alias for {@link #qps()}.
     */
    @AliasFor("qps")
    double value() default NOT_LIMITED;

    /**
     * Allowed queries per second. Alias for {@link #value()}.
     */
    @AliasFor("value")
    double qps() default NOT_LIMITED;

    /**
     * Timeout passed to the limiter when trying to acquire a permit,
     * expressed in {@link #timeUnit()}. Defaults to 0.
     */
    int timeout() default 0;

    /**
     * Unit of {@link #timeout()}. Defaults to milliseconds.
     */
    TimeUnit timeUnit() default TimeUnit.MILLISECONDS;
}
然后定义一个切面:当识别到方法上标注了 RateLimiter 注解时,尝试获取令牌;若在超时时间内获取失败,则抛出异常
/**
 * Aspect that enforces the {@link RateLimiter} annotation using Guava's
 * {@code RateLimiter}.
 *
 * <p>Each annotated method gets its own Guava limiter, created lazily on the
 * first intercepted call and cached for the lifetime of the JVM. When a permit
 * cannot be obtained within the annotation's timeout, a
 * {@link RuntimeException} is thrown instead of invoking the target method.
 */
@Slf4j
@Aspect
@Component
public class RateLimiterAspect {
    /**
     * Limiter cache keyed by the method's full generic signature
     * (declaring class, name and parameter types), so that methods which merely
     * share a simple name, such as overloads or methods in different
     * controllers, do not end up sharing one limiter.
     */
    private static final ConcurrentMap<String, com.google.common.util.concurrent.RateLimiter> RATE_LIMITER_CACHE = new ConcurrentHashMap<>();

    /**
     * Matches every method annotated with {@link RateLimiter}.
     */
    @Pointcut("@annotation(com.xkcoding.ratelimit.guava.annotation.RateLimiter)")
    public void rateLimit() {
    }

    /**
     * Applies the rate limit declared on the intercepted method, then proceeds.
     *
     * @param point the intercepted invocation
     * @return whatever the target method returns
     * @throws RuntimeException if no permit could be acquired within the configured timeout
     * @throws Throwable        anything thrown by the target method
     */
    @Around("rateLimit()")
    public Object pointcut(ProceedingJoinPoint point) throws Throwable {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        // findAnnotation is used so that the value/qps @AliasFor pair is resolved.
        RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
        if (rateLimiter != null && rateLimiter.qps() > RateLimiter.NOT_LIMITED) {
            double qps = rateLimiter.qps();
            // The simple method name is not unique; the generic signature is.
            String cacheKey = method.toGenericString();
            // computeIfAbsent makes "create on first use" atomic, so concurrent first
            // requests cannot each build and use a different limiter.
            com.google.common.util.concurrent.RateLimiter limiter = RATE_LIMITER_CACHE.computeIfAbsent(cacheKey, key -> com.google.common.util.concurrent.RateLimiter.create(qps));
            log.debug("【{}】的QPS设置为: {}", method.getName(), limiter.getRate());
            // Try to take a permit, waiting at most the configured timeout.
            if (!limiter.tryAcquire(rateLimiter.timeout(), rateLimiter.timeUnit())) {
                throw new RuntimeException("手速太快了,慢点儿吧~");
            }
        }
        return point.proceed();
    }
}
然后定义一个全局异常处理器,拦截运行时异常(RuntimeException),并将异常信息作为响应返回
/**
 * Global handler that turns any {@link RuntimeException} raised by a controller
 * into a response body containing the exception message.
 */
@RestControllerAdvice
public class GlobalExceptionHandler {

    /**
     * Builds the error response for a runtime exception.
     *
     * @param ex the exception that was thrown
     * @return a {@code Dict} with a single {@code msg} entry holding {@code ex.getMessage()}
     */
    @ExceptionHandler(RuntimeException.class)
    public Dict handler(RuntimeException ex) {
        Dict body = Dict.create();
        body.set("msg", ex.getMessage());
        return body;
    }
}
下面是一个用于测试的 demo Controller
/**
 * Demo endpoints for trying out the {@link RateLimiter} annotation:
 * {@code /test1} and {@code /test3} are annotated with different rates,
 * while {@code /test2} carries no annotation.
 */
@Slf4j
@RestController
public class TestController {

    /** Value of the {@code msg} entry returned by every endpoint. */
    private static final String GREETING = "hello,world!";

    /** Description returned by the rate-limited endpoints. */
    private static final String LIMITED_DESCRIPTION = "别想一直看到我,不信你快速刷新看看~";

    /** Description returned by the endpoint without a rate limit. */
    private static final String UNLIMITED_DESCRIPTION = "我一直都在,卟离卟弃";

    /**
     * Endpoint annotated with a rate of 1.0 and a timeout of 300 (default unit).
     */
    @GetMapping("/test1")
    @RateLimiter(value = 1.0, timeout = 300)
    public Dict test1() {
        log.info("【test1】被执行了。。。。。");
        return reply(LIMITED_DESCRIPTION);
    }

    /**
     * Endpoint without a {@link RateLimiter} annotation.
     */
    @GetMapping("/test2")
    public Dict test2() {
        log.info("【test2】被执行了。。。。。");
        return reply(UNLIMITED_DESCRIPTION);
    }

    /**
     * Endpoint annotated with a rate of 2.0 and a timeout of 300 (default unit).
     */
    @GetMapping("/test3")
    @RateLimiter(value = 2.0, timeout = 300)
    public Dict test3() {
        log.info("【test3】被执行了。。。。。");
        return reply(LIMITED_DESCRIPTION);
    }

    /** Builds the response body: the shared greeting followed by the given description. */
    private static Dict reply(String description) {
        Dict body = Dict.create();
        body.set("msg", GREETING);
        body.set("description", description);
        return body;
    }
}