SpringBoot 使用 Google 缓存框架 + bucket4j 限流框架实现的基于内存的单机限流,分布式限流可以在网关层使用。
(1)引入依赖
<dependencies>
<!-- Google 缓存框架 -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>31.1-jre</version>
</dependency>
<!-- 使用 SpringBoot Aop 框架 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- bucket4j 限流框架 -->
<dependency>
<groupId>com.github.vladimir-bukhtoyarov</groupId>
<artifactId>bucket4j-core</artifactId>
<version>7.6.0</version>
</dependency>
<!-- 注入常用工具类 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.8.18</version>
</dependency>
</dependencies>
(2)定义 @Throttling 和 @Throttlings 接口限流注解
/**
* @author chenyusheng
* @create 2023/6/28 9:56
* @description 接口限流注解
* 注:可以用在方法和类上,可以同时存在多个
*/
@Inherited
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Throttling {
/**
* 生成的令牌数量
*/
int capacity() default 5;
/**
* 每单位时间,生成的令牌数量,默认:1 秒内生成 5 个令牌
*/
long period() default 1;
/**
* 时间单位,默认:秒
*/
TimeUnit timeUnit() default TimeUnit.SECONDS;
/**
* 唯一标识,需要实现 ThrottlingKeyProvider 接口
*/
Class<? extends ThrottlingKeyProvider> keyProvider();
/**
* 缓存 key 的前缀
* (1)默认使用方法的签名的 MD5,作为缓存 key 的前缀,用于区分是为了给哪个方法设置限流;
* (2)如果需要要对全局进行限流,需设置固定值,例如想实现全局IP的限流,可以采用固定 prefix + (IP)keyProvider 的模式,相当于提供某一组接口共用一个限流配置
*/
String prefix() default "";
}
/**
* @author chenyusheng
* @create 2023/6/28 9:56
* @description 接口限流注解
*/
@Inherited
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Throttlings {
Throttling[] value();
}
/**
* @author chenyusheng
* @create 2023/6/28 9:56
* @description 限流唯一 key 生成接口
*/
public interface ThrottlingKeyProvider {
/**
* 生成限流标识
*
* @return
*/
String generateUniqKey();
}
(3)定义 ThrottlingAop 限流 AOP
/**
* @author chenyusheng
* @create 2023/6/28 9:56
* @description AOP切面 抽象基类
*/
public abstract class AbstractAop {
/**
* 获取 Method 的指定注解
*
* @param joinPoint
* @return T extends Annotation
*/
protected <T extends Annotation> T getMethodAnnotation(ProceedingJoinPoint joinPoint, Class<T> clazz) {
T annotation = null;
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
Class<?>[] paramTypeArray = methodSignature.getParameterTypes();
Method transferMoney = null;
try {
transferMoney = joinPoint.getTarget().getClass().getDeclaredMethod(methodSignature.getName(), paramTypeArray);
} catch (NoSuchMethodException e) {
e.printStackTrace();
}
assert transferMoney != null;
boolean annotationPresent = transferMoney.isAnnotationPresent(clazz);
if (annotationPresent) {
annotation = transferMoney.getAnnotation(clazz);
}
return annotation;
}
/**
* 获取当前类及其父类包含的指定的注解类型
*
* @param clazz
* @param annotation
* @param <T>
* @return
*/
protected <T extends Annotation> List<T> getInheritClassAnnotation(Class<?> clazz, Class<T> annotation) {
List<T> list = new ArrayList<>();
// 获取当前类及其父类包含的指定的T类型注解
do {
T annot = getClassAnnotation(clazz, annotation);
if (annot != null) {
list.add(annot);
}
clazz = clazz.getSuperclass();
} while (clazz != null);
return list;
}
/**
* 获取指定类的注解
*
* @param clazz
* @param annotation
* @param <T>
* @return
*/
protected <T extends Annotation> T getClassAnnotation(Class<?> clazz, Class<T> annotation) {
T res = null;
boolean annotationPresent = clazz.isAnnotationPresent(annotation);
if (annotationPresent) {
res = clazz.getAnnotation(annotation);
}
return res;
}
}
/**
* @author chenyusheng
* @create 2023/6/28 9:56
* @description 限流 AOP
* 注意:基于内存保存数据,缓存的数据量不大,且仅仅支持单体应用限流.
*/
@Aspect
@Component
@Order(1)
public class ThrottlingAop extends AbstractAop {
/**
* 基于内存保存数据,缓存的数据量不大,且仅仅支持单体应用.
*/
public static Cache<String, Bucket> localCache = CacheBuilder.newBuilder()
// 指定缓存项在给定时间内没有被访问,则回收
.expireAfterAccess(1, TimeUnit.DAYS)
// 设置缓存容器的初始容量
.initialCapacity(200)
// 设置缓存最大容量,超过之后就会按照LRU最近最少使用算法来移除缓存项
.maximumSize(10000)
// 设置并发级别为8,并发级别是指可以同时写缓存的线程数
.concurrencyLevel(8).build();
@Resource
private ApplicationContext applicationContext;
/**
* 基于方法上的注解
*/
@Pointcut("@annotation(cn.yifants.common.extension.aop.throttling.Throttling) || @annotation(cn.yifants.common.extension.aop.throttling.Throttlings)")
public void methodAnnotation() {
}
/**
* 基于类上的注解
*/
@Pointcut("@within(cn.yifants.common.extension.aop.throttling.Throttling) || @within(cn.yifants.common.extension.aop.throttling.Throttlings)")
public void classAnnotation() {
}
@Around("methodAnnotation() || classAnnotation()")
public Object rounding(ProceedingJoinPoint joinPoint) throws Throwable {
List<Throttling> annotations = new ArrayList<>();
// 解析 Throttling 注解
Throttling throttling = getMethodAnnotation(joinPoint, Throttling.class);
if (throttling != null) {
annotations.add(throttling);
}
annotations.addAll(getInheritClassAnnotation(joinPoint.getTarget().getClass(), Throttling.class));
// 解析 Throttlings 注解
Throttlings throttlings = getMethodAnnotation(joinPoint, Throttlings.class);
if (throttlings != null) {
annotations.addAll(Arrays.asList(throttlings.value()));
}
getInheritClassAnnotation(joinPoint.getTarget().getClass(), Throttlings.class).forEach(x -> {
annotations.addAll(Arrays.asList(x.value()));
});
// 限流处理
if (CollectionUtil.isNotEmpty(annotations)) {
// 没有指定 prefix 时,默认使用方法的签名的 MD5,作为缓存 key 的前缀,用于区分是为了给哪个方法设置限流
String signature = SecureUtil.md5(joinPoint.getSignature().toString());
annotations.stream().distinct().collect(Collectors.groupingBy(x -> ThrottlingKey.builder().keyProvider(x.keyProvider()).prefix(x.prefix()).build())).forEach((key, val) -> {
if (StringUtils.isBlank(key.getPrefix())) {
key.setPrefix(signature);
}
Throttlinghandle(key, val);
});
}
Object[] args = joinPoint.getArgs();
return joinPoint.proceed(args);
}
/**
* 限流处理
*
* @param keyProvider
* @param list
*/
private synchronized void Throttlinghandle(ThrottlingKey keyProvider, List<Throttling> list) {
ThrottlingKeyProvider throttlingKeyProvider = applicationContext.getBean(keyProvider.getKeyProvider(), ThrottlingKeyProvider.class);
// 限流唯一 key
String uniqKey = keyProvider.getPrefix() + ":" + throttlingKeyProvider.generateUniqKey();
Bucket bucket = localCache.getIfPresent(uniqKey);
if (bucket == null) {
bucket = createNewBucket(list);
localCache.put(uniqKey, bucket);
}
// 尝试获取令牌
ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
// 获取不到令牌,直接抛出限流异常
if (!probe.isConsumed()) {
throw new ServiceException(StatusCodeEnum.TOO_MANY_REQUESTS, "Too many requests,X-Rate-Limit-Retry-After-Seconds:" + TimeUnit.NANOSECONDS.toSeconds(probe.getNanosToWaitForRefill()) + ",keyProvider:" + keyProvider);
}
}
/**
* 创建一个限流规则
*
* @return
*/
private Bucket createNewBucket(List<Throttling> list) {
LocalBucketBuilder builder = Bucket.builder();
list.forEach(x -> {
// 每 period 单位时间内最高调用 capacity 次数
switch (x.timeUnit()) {
case SECONDS:
builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofSeconds(x.period()))));
break;
case MINUTES:
builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofMinutes(x.period()))));
break;
case HOURS:
builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofHours(x.period()))));
break;
case DAYS:
builder.addLimit(Bandwidth.classic(x.capacity(), Refill.intervally(x.capacity(), Duration.ofDays(x.period()))));
break;
default:
throw new IllegalStateException("Unexpected value: " + x.timeUnit());
}
});
return builder.build();
}
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public static class ThrottlingKey {
/**
* 唯一标识,需要实现 ThrottlingKeyProvider 接口
*/
private Class<? extends ThrottlingKeyProvider> keyProvider;
/**
* 缓存 key 的前缀
* (1)默认使用方法的签名的 MD5,作为缓存 key 的前缀,用于区分是为了给哪个方法设置限流;
* (2)如果需要要对全局进行限流,需设置固定值,例如想实现全局IP的限流,可以采用固定 prefix + (IP)keyProvider 的模式
*/
private String prefix;
}
}
(4)使用方式
- 实现 ThrottlingKeyProvider 接口,生成限流唯一 key;
/**
* @author chenyusheng
* @create 2023/6/28 9:56
* @description 自定义限流 实现类
*/
@Slf4j
@Component("ipThrottlingKeyProviderImpl")
public class IpThrottlingKeyProviderImpl implements ThrottlingKeyProvider {
/**
* HttpServletRequest 请求上下文
*/
@Resource
protected HttpServletRequest httpServletRequest;
/**
* 生成限流标识
*
* @return
*/
@Override
public String generateUniqKey() {
// 获取当前请求的 IP
String ip = IpUtil.getIpAddress(httpServletRequest);
if (StringUtils.isBlank(ip)) {
throw new ServiceException(StatusCodeEnum.SERVICE_UNAVAILABLE, "无法获取当前请求的IP");
}
return ip;
}
}
- 使用注解 @Throttling(capacity = 2, keyProvider = IpThrottlingKeyProviderImpl.class)