pom.xml 依赖:
<!-- Project dependencies for the Guava-based rate-limiting demo. Versions (except spring-aspects)
     are not declared here and presumably come from a parent POM / dependencyManagement that is
     not shown; confirm before changing any of them. -->
<dependencies>
<!-- Web layer: @RestController / @RequestMapping used by the Login demo controller. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<!-- NOTE(review): not referenced by the code shown alongside this POM. fastjson 1.x has had
     published deserialization (autoType) advisories; remove it if it is unused. -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
</dependency>
<!-- StringUtils.isBlank / isNotBlank used by the rate-limiting aspect. -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<!-- Lombok supplies @Slf4j. It is only needed at compile time; consider <optional>true</optional>
     or <scope>provided</scope> so it is not packaged with the application. -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<!-- NOTE(review): JUnit 4 declared next to spring-boot-starter-test; whether this (and a vintage
     engine) is still needed depends on the Spring Boot version in use. Confirm. -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<!-- Guava: Cache / CacheBuilder and RateLimiter used by the rate-limiting aspect. -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
<!-- AspectJ annotation support (@Aspect, @Around, @Pointcut) for the aspect.
     NOTE(review): this is the only dependency with a pinned version. If the parent is
     spring-boot-starter-parent, this pin may diverge from the Spring version it manages;
     consider dropping <version> and letting the parent manage it. -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-aspects</artifactId>
<version>5.2.6.RELEASE</version>
</dependency>
</dependencies>
AccessInterceptor 自定义注解:
/**
 * Marks a method for per-caller rate limiting with an optional blacklist.
 *
 * <p>The accompanying {@code RateLimiterAOP} aspect selects join points with an
 * {@code @annotation(...)} pointcut, which matches method-level annotations only. Placing this
 * annotation on a type compiles (the {@code TYPE} target is kept for compatibility) but is not
 * enforced by that aspect.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface AccessInterceptor {

    /**
     * Name of the value used to identify the caller. The aspect reads this field from the
     * intercepted method's arguments.
     *
     * <p>Defaults to {@code "fingerprint"}. How a missing value is handled is decided by the
     * aspect, not by this annotation.
     */
    String key() default "fingerprint";

    /** Permitted requests per second for a single caller identifier. */
    double permitsPerSecond();

    /**
     * Rejection threshold for the blacklist. Once a caller has been rate-limited more than this
     * many times, further calls are answered by the fallback without consulting the limiter.
     * {@code 0} disables the blacklist.
     */
    double blacklistCount() default 0;

    /**
     * Name of the method whose result is returned when a call is rejected. It must be
     * {@code public}, be declared on (or inherited by) the same class, and have exactly the same
     * parameter types as the annotated method. It is invoked with the same arguments.
     */
    String fallbackMethod();
}
RateLimiterAOP 切面织入:
@Slf4j
@Aspect
@Component
public class RateLimiterAOP {
// 个人限频记录1分钟
private final Cache<String, RateLimiter> loginRecord = CacheBuilder.newBuilder()
.expireAfterWrite(1, TimeUnit.MINUTES)
.build();
// 个人限频黑名单24h - 自身的分布式业务场景,可以记录到 Redis 中
private final Cache<String, Long> blacklist = CacheBuilder.newBuilder()
.expireAfterWrite(24, TimeUnit.HOURS)
.build();
@Pointcut("@annotation(cn.bugstack.xfg.dev.tech.annotation.AccessInterceptor)")
public void aopPoint() {
}
@Around("aopPoint() && @annotation(accessInterceptor)")
public Object doRouter(ProceedingJoinPoint jp, AccessInterceptor accessInterceptor) throws Throwable {
String key = accessInterceptor.key();
if (StringUtils.isBlank(key)) {
throw new RuntimeException("annotation RateLimiter uId is null!");
}
// 获取拦截字段
String keyAttr = getAttrValue(key, jp.getArgs());
log.info("aop attr {}", keyAttr);
// 黑名单拦截
if (!"all".equals(keyAttr) && accessInterceptor.blacklistCount() != 0 && null != blacklist.getIfPresent(keyAttr) && blacklist.getIfPresent(keyAttr) > accessInterceptor.blacklistCount()) {
log.info("限流-黑名单拦截(24h):{}", keyAttr);
return fallbackMethodResult(jp, accessInterceptor.fallbackMethod());
}
// 获取限流 -> Guava 缓存1分钟
RateLimiter rateLimiter = loginRecord.getIfPresent(keyAttr);
if (null == rateLimiter) {
rateLimiter = RateLimiter.create(accessInterceptor.permitsPerSecond());
loginRecord.put(keyAttr, rateLimiter);
}
// 限流拦截
if (!rateLimiter.tryAcquire()) {
if (accessInterceptor.blacklistCount() != 0) {
if (null == blacklist.getIfPresent(keyAttr)) {
blacklist.put(keyAttr, 1L);
} else {
blacklist.put(keyAttr, blacklist.getIfPresent(keyAttr) + 1L);
}
}
log.info("限流-超频次拦截:{}", keyAttr);
return fallbackMethodResult(jp, accessInterceptor.fallbackMethod());
}
// 返回结果
return jp.proceed();
}
/**
* 调用用户配置的回调方法,当拦截后,返回回调结果。
*/
private Object fallbackMethodResult(JoinPoint jp, String fallbackMethod) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
Signature sig = jp.getSignature();
MethodSignature methodSignature = (MethodSignature) sig;
Method method = jp.getTarget().getClass().getMethod(fallbackMethod, methodSignature.getParameterTypes());
return method.invoke(jp.getThis(), jp.getArgs());
}
private Method getMethod(JoinPoint jp) throws NoSuchMethodException {
Signature sig = jp.getSignature();
MethodSignature methodSignature = (MethodSignature) sig;
return jp.getTarget().getClass().getMethod(methodSignature.getName(), methodSignature.getParameterTypes());
}
/**
* 实际根据自身业务调整,主要是为了获取通过某个值做拦截
*/
public String getAttrValue(String attr, Object[] args) {
if (args[0] instanceof String) {
return args[0].toString();
}
String filedValue = null;
for (Object arg : args) {
try {
if (StringUtils.isNotBlank(filedValue)) {
break;
}
// filedValue = BeanUtils.getProperty(arg, attr);
// fix: 使用lombok时,uId这种字段的get方法与idea生成的get方法不同,会导致获取不到属性值,改成反射获取解决
filedValue = String.valueOf(this.getValueByName(arg, attr));
} catch (Exception e) {
log.error("获取路由属性值失败 attr:{}", attr, e);
}
}
return filedValue;
}
/**
* 获取对象的特定属性值
*
* @param item 对象
* @param name 属性名
* @return 属性值
* @author tang
*/
private Object getValueByName(Object item, String name) {
try {
Field field = getFieldByName(item, name);
if (field == null) {
return null;
}
field.setAccessible(true);
Object o = field.get(item);
field.setAccessible(false);
return o;
} catch (IllegalAccessException e) {
return null;
}
}
/**
* 根据名称获取方法,该方法同时兼顾继承类获取父类的属性
*
* @param item 对象
* @param name 属性名
* @return 该属性对应方法
* @author tang
*/
private Field getFieldByName(Object item, String name) {
try {
Field field;
try {
field = item.getClass().getDeclaredField(name);
} catch (NoSuchFieldException e) {
field = item.getClass().getSuperclass().getDeclaredField(name);
}
return field;
} catch (NoSuchFieldException e) {
return null;
}
}
GuavaCacheConfig 配置类:
/**
 * Spring configuration that declares a Guava cache bean and a bean for the rate-limiting aspect.
 *
 * NOTE(review): RateLimiterAOP is itself annotated with @Component. If it lives under
 * cn.bugstack.xfg.dev.tech (scanned below), the @Bean method rateLimiter() likely registers a
 * second aspect bean. The advice could then run twice per call, each instance with its own
 * caches. Keep exactly one registration. TODO: confirm the aspect's package before removing
 * either one.
 */
@Configuration
@ComponentScan("cn.bugstack.xfg.dev.tech")
public class GuavaCacheConfig {
// String cache exposed under the bean name "codeCache"; entries expire 3 minutes after being written.
@Bean(name = "codeCache")
public Cache<String, String> codeCache() {
return CacheBuilder.newBuilder()
.expireAfterWrite(3, TimeUnit.MINUTES)
.build();
}
// Registers the rate-limiting aspect as a bean (see the class-level note about duplicate registration).
@Bean
public RateLimiterAOP rateLimiter(){
return new RateLimiterAOP();
}
}
Login 测试控制器:
/**
 * Demo controller for {@link AccessInterceptor} rate limiting.
 *
 * <p>Call the endpoint repeatedly to trigger limiting, for example:
 * {@code http://localhost:8091/api/ratelimiter/login?fingerprint=uljpplllll01009&uId=1000&token=8790}
 */
@Slf4j
@RestController
@RequestMapping("/api/ratelimiter/")
public class Login {

    /**
     * Simulated login, limited to 1 request per second per fingerprint. A fingerprint is
     * blacklisted once it has been rejected more than 10 times.
     *
     * @param fingerprint caller identifier used for rate limiting
     * @param uId         user id echoed back in the response
     * @param token       login token (unused by this demo)
     * @return a success message containing {@code uId}
     */
    @AccessInterceptor(key = "fingerprint", fallbackMethod = "loginErr", permitsPerSecond = 1.0d, blacklistCount = 10)
    @RequestMapping(value = "login", method = RequestMethod.GET)
    public String login(String fingerprint, String uId, String token) {
        log.info("模拟登录 fingerprint:{}", fingerprint);
        return "模拟登录:登录成功 " + uId;
    }

    /**
     * Fallback returned when {@link #login} is rate-limited or blacklisted.
     *
     * <p>It must stay {@code public} with exactly the same parameter types as {@link #login}. The
     * rate-limiting aspect resolves it with {@code Class#getMethod(name, parameterTypes)}; without
     * it, every rejected request would fail with {@code NoSuchMethodException}.
     *
     * @param fingerprint caller identifier (unused)
     * @param uId         user id (unused)
     * @param token       login token (unused)
     * @return a rate-limit notice
     */
    public String loginErr(String fingerprint, String uId, String token) {
        return "频次限制,请勿恶意访问!";
    }
}