基于 AOP + Guava 实现一个限流拦截
定义限流标识注解
其中 key() 代表限流所依据的字段，例如 userId；
permitPerSecond() 代表每秒最大访问量；
blacklistCount() 代表违规（被限流）次数超过多少次后进入黑名单，为 0 时不启用黑名单；
fallbackMethod() 代表被限流或进入黑名单之后调用的降级方法。
/**
 * Marks a method whose invocations are throttled by the rate-limiting aspect.
 *
 * <p>The annotation is retained at runtime so the aspect can read its values when the
 * method is intercepted.
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiterAccessInterceptor {

    /**
     * Name of the argument field whose value identifies the caller (for example {@code userId}).
     * Defaults to {@code "all"}.
     */
    String key() default "all";

    /** Maximum number of permits handed out per second for one key value. */
    double permitPerSecond();

    /**
     * Number of rate-limit violations a key may accumulate before it is blacklisted;
     * {@code 0} (the default) disables blacklisting.
     */
    double blacklistCount() default 0;

    /** Name of the method invoked instead of the annotated one when a call is rejected. */
    String fallbackMethod();
}
定义切面
其中 loginRecord 记录了标识字段 key 的取值所对应的限流器 RateLimiter；
每被限流一次，blackList 中该 key 的违规次数 + 1；
如果违规次数超过了注解中定义的 blacklistCount，则直接调用 fallbackMethod 指定的降级方法。
@Aspect
@Slf4j
@Component
public class RateLimiterAOP {
// 限流
private final Cache<String,RateLimiter> loginRecord = CacheBuilder.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES).build();
// 存放黑名单记录
private final Cache<String,Long> blackList = CacheBuilder.newBuilder()
.expireAfterAccess(24, TimeUnit.HOURS).build();
@Pointcut("@annotation(cn.bugstack.types.annotations.RateLimiterAccessInterceptor)")
public void aopPoint(){
}
@Around("aopPoint() && @annotation(rateLimiterAccessInterceptor)")
public Object doRouter(ProceedingJoinPoint jp,RateLimiterAccessInterceptor rateLimiterAccessInterceptor) throws Throwable {
String key = rateLimiterAccessInterceptor.key();
if(StringUtils.isBlank(key)){
throw new RuntimeException("uid is null");
}
String keyAttr = getAttrValue(key, jp.getArgs());
log.info("aop attr: {}",keyAttr);
// 如果当前用户违规次数超过了我规定的次数,那么直接降级.
if(!"all".equals(keyAttr) && rateLimiterAccessInterceptor.blacklistCount() != 0 && null != blackList.getIfPresent(keyAttr) && blackList.getIfPresent(keyAttr) > rateLimiterAccessInterceptor.blacklistCount()){
log.info("限流-超频次拦截: {}",keyAttr);
return fallbackMethodResult(jp,rateLimiterAccessInterceptor.fallbackMethod());
}
// 限流 基于用户id
RateLimiter rateLimiter = loginRecord.getIfPresent(keyAttr);
if(null == rateLimiter){
rateLimiter = RateLimiter.create(rateLimiterAccessInterceptor.permitPerSecond());
loginRecord.put(keyAttr,rateLimiter);
}
// 如果超时调用一次就在blackList中增加一次违规次数.
if(!rateLimiter.tryAcquire()){
if(rateLimiterAccessInterceptor.blacklistCount() != 0){
if(null == blackList.getIfPresent(keyAttr)){
blackList.put(keyAttr,1L);
}else{
blackList.put(keyAttr,blackList.getIfPresent(keyAttr) + 1L);
}
}
log.info("限流-超频次拦截: {}",keyAttr);
return fallbackMethodResult(jp,rateLimiterAccessInterceptor.fallbackMethod());
}
return jp.proceed();
}
// 触发fallbackMethod.
private Object fallbackMethodResult(JoinPoint jp,String fallbackMethod) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
Signature sig = jp.getSignature();
MethodSignature methodSignature = (MethodSignature) sig;
Method method = jp.getTarget().getClass().getMethod(fallbackMethod,methodSignature.getParameterTypes());
return method.invoke(jp.getThis(),jp.getArgs());
}
// 从属性中获取值.
public String getAttrValue(String attr,Object[] args){
if(args[0] instanceof String){
return args[0].toString();
}
String fieldValue = null;
for(Object arg : args){
try {
if(StringUtils.isNotBlank(fieldValue)) break;
fieldValue = String.valueOf(this.getValueByName(arg,attr));
}catch (Exception e){
log.error("获取属性失败 attr: {}",attr,e);
}
}
return fieldValue;
}
// 从field中获取值
private Object getValueByName(Object item,String name){
try {
Field field = getFieldByName(item,name);
if(field == null) return null;
field.setAccessible(true);
Object o = field.get(item);
field.setAccessible(false);
return o;
}catch (IllegalAccessException e){
return null;
}
}
// 从字段中获取name匹配的field
private Field getFieldByName(Object item,String name){
try {
Field field;
try {
field = item.getClass().getDeclaredField(name);
}catch (NoSuchFieldException e){
field = item.getClass().getSuperclass().getDeclaredField(name);
}
return field;
}catch (NoSuchFieldException e){
return null;
}
}
}