Google 开源工具包 Guava 提供了限流工具类 RateLimiter，它基于令牌桶算法实现流量限制，使用十分方便。下面结合 Spring AOP 与自定义注解，为接口方法添加限流。
1. 在 Maven 中引入依赖
<!-- Spring Boot Starter for AOP: provides @Aspect / @Around support used by the rate-limit aspect.
     No <version> here; presumably managed by the Spring Boot parent/BOM - confirm. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Rate limiting: Guava provides com.google.common.util.concurrent.RateLimiter (token bucket). -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>30.1-jre</version>
</dependency>
2. 新建限流切面类
import com.cx.sasmerp.limitflow.RateLimit;
import com.google.common.util.concurrent.RateLimiter;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * <p>
 * Rate-limiting aspect.
 * </p>
 * Intercepts every method annotated with {@link RateLimit} and lets the call through only when a
 * permit can be taken immediately from the Guava {@link RateLimiter} stored under that method's
 * limit key; otherwise the call is rejected with a "system busy" message.
 *
 * @author Lch
 * @dateTime 2024/3/5 10:11
 */
@Component
@Scope
@Aspect
public class RateLimitAspect {

    /**
     * Limiters by limit key: the annotation's {@code limitKey}, or the method name when that is
     * empty. Static so limiters can be pre-registered from outside the aspect; a ConcurrentHashMap
     * because request threads read it and lazily add to it concurrently. Note that methods sharing
     * a key (e.g. same-named methods in different classes) share one limiter.
     */
    public static final Map<String, RateLimiter> rateLimitMap = new ConcurrentHashMap<>();

    /**
     * Pointcut matching every method annotated with {@link RateLimit}.
     */
    @Pointcut("@annotation(com.cx.sasmerp.limitflow.RateLimit)")
    public void ServiceAspect() {
    }

    /**
     * Applies the rate limit around an annotated method.
     *
     * <p>Uses the non-blocking {@link RateLimiter#tryAcquire()}: a call that cannot get a permit
     * right away is rejected instead of waiting (whereas {@code acquire()} would block).
     *
     * @param joinPoint the intercepted invocation
     * @return the method's own result when a permit was granted; otherwise the string
     *     {@code "The system is busy, please visit after a while"}
     * @throws Throwable anything the intercepted method throws, propagated unchanged so callers
     *     (e.g. exception handlers) see the real failure instead of a {@code null} result
     * @throws IllegalStateException if no {@code @RateLimit} method with the invoked name exists on
     *     the target class or its superclasses
     * @throws IllegalArgumentException if a limiter has to be created here and the annotation's
     *     {@code value} is not positive (raised by {@link RateLimiter#create(double)})
     */
    @Around("ServiceAspect()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        Class<?> clz = joinPoint.getTarget().getClass();
        Signature signature = joinPoint.getSignature();
        Method method = findRateLimitedMethod(clz, signature.getName());
        if (method == null) {
            // Surface the misconfiguration instead of an NPE on a null map key.
            throw new IllegalStateException(
                    "No @RateLimit method named '" + signature.getName() + "' found on " + clz.getName());
        }
        if (limiterFor(method).tryAcquire()) {
            return joinPoint.proceed();
        }
        // Request rejected (degraded response).
        // NOTE(review): this String becomes the method's return value whatever its declared type;
        // for methods not returning String/Object the proxy will likely fail with a
        // ClassCastException - consider throwing a dedicated exception instead.
        return "The system is busy, please visit after a while";
    }

    /**
     * Finds the first {@code @RateLimit}-annotated method with the given name, searching the class
     * and then its superclasses (getDeclaredMethods alone misses inherited methods).
     */
    private static Method findRateLimitedMethod(Class<?> clz, String methodName) {
        for (Class<?> type = clz; type != null; type = type.getSuperclass()) {
            for (Method method : type.getDeclaredMethods()) {
                if (method.getName().equals(methodName) && method.isAnnotationPresent(RateLimit.class)) {
                    return method;
                }
            }
        }
        return null;
    }

    /** Returns the limiter for an annotated method, creating and registering it on first use. */
    private static RateLimiter limiterFor(Method method) {
        RateLimit rateLimit = method.getAnnotation(RateLimit.class);
        String key = rateLimit.limitKey().isEmpty() ? method.getName() : rateLimit.limitKey();
        // Pre-registered limiters are reused; methods nobody registered (e.g. outside the beans
        // scanned at startup) get a limiter built from their annotation instead of failing.
        return rateLimitMap.computeIfAbsent(key, k -> RateLimiter.create(rateLimit.value()));
    }
}
3. 新建自定义注解（@interface）
import java.lang.annotation.*;
/**
 * <p>
 * Marks a method whose invocations are rate limited.
 * </p>
 * Kept at runtime so it can be read reflectively. {@code PARAMETER} is also a permitted target,
 * though only method-level use is meaningful for limiting.
 *
 * @author Lch
 * @dateTime 2024/3/5 10:09
 */
@Target({ElementType.PARAMETER, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimit {

    /**
     * Key the method's limiter is stored under. When left empty (the default), the method name
     * is used as the key.
     */
    String limitKey() default "";

    /**
     * Permits issued per second; the value is handed to Guava's {@code RateLimiter.create}.
     * It must be positive in practice: the default of zero is rejected when the limiter is built.
     */
    double value() default 0.0;
}
4. 新建一个类，在启动时为每个限流方法初始化许可证数量（即每秒发放的许可数）
import com.cx.sasmerp.aspect.RateLimitAspect;
import com.google.common.util.concurrent.RateLimiter;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.RestController;
import java.lang.reflect.Method;
import java.util.Map;
/**
 * <p>
 * Initializes the permits of every rate-limited controller method at startup.
 * </p>
 * When the application context is injected, every {@link RestController} bean is scanned for
 * {@link RateLimit}-annotated methods, and one Guava {@link RateLimiter} per limit key is put into
 * {@link RateLimitAspect#rateLimitMap}. The rate is taken from the annotation; it could instead
 * be looked up from configuration by key.
 *
 * @author Lch
 * @dateTime 2024/3/5 10:15
 */
@Component
public class InitRateLimit implements ApplicationContextAware {

    /**
     * Scans all {@code @RestController} beans and registers their limiters.
     *
     * <p>The bean's whole class hierarchy (excluding {@code Object}) is scanned instead of assuming
     * the bean is a proxy subclass whose direct superclass is the controller. This also covers
     * unproxied controllers and annotated methods inherited from base classes. Method annotations
     * are not inherited by overrides, so a generated proxy subclass contributes nothing itself and
     * the controller class above it is still found.
     * NOTE(review): controllers behind a JDK interface proxy are still not discovered this way.
     *
     * @param applicationContext the context to read controller beans from
     * @throws BeansException as declared by {@link ApplicationContextAware}
     * @throws IllegalArgumentException if an annotated method's {@code value} is not a positive number
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        Map<String, Object> beanMap = applicationContext.getBeansWithAnnotation(RestController.class);
        beanMap.values().forEach(bean -> {
            for (Class<?> type = bean.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
                for (Method method : type.getDeclaredMethods()) {
                    RateLimit rateLimit = method.getAnnotation(RateLimit.class);
                    if (rateLimit != null) {
                        register(method, rateLimit);
                    }
                }
            }
        });
    }

    /** Derives the limit key of one annotated method and registers a limiter for it. */
    private static void register(Method method, RateLimit rateLimit) {
        String key = rateLimit.limitKey().isEmpty() ? method.getName() : rateLimit.limitKey();
        double value = rateLimit.value();
        // RateLimiter.create rejects zero/negative/NaN with a terse message; fail with context
        // instead (the negated comparison also catches NaN).
        if (!(value > 0.0)) {
            throw new IllegalArgumentException(
                    "@RateLimit value must be positive, got " + value + " on " + method);
        }
        // Keep the first registration of a key: the hierarchy walk visits an override before the
        // method it overrides, so the most-derived annotation wins, and a key shared by two
        // controllers is reported instead of silently replacing the earlier limiter.
        if (RateLimitAspect.rateLimitMap.putIfAbsent(key, RateLimiter.create(value)) == null) {
            System.out.println("RatelimitKey:" + key + ",许可证数是:" + value);
        } else {
            System.out.println("RatelimitKey:" + key + " already registered, ignoring duplicate on " + method);
        }
    }
}
5. 使用：在需要限流的接口方法上添加 @RateLimit 注解。例如下面的配置表示该接口每秒发放 5 个许可，拿不到许可的请求会被立即拒绝（tryAcquire 不阻塞）：
@RateLimit(value = 5)
@PostMapping("/statisticsExportData")