自定义注解
/**
* 速率限制注解
*
* @author: 张定辉
* @date: 2024/3/5 21:29
* @description: 速率限制注解
*/
@Target({ElementType.TYPE,ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {
/**
* SPEL表达式
* <p>
* 1.使用方法的基本类型参数作为限流Key
* <p>
* @RateLimit(value="#id")
* public void test(String id){}
* <p><p>
* 2.使用方法的对象类型参数中的某个属性作为限流Key
* <p>
* @RateLimit(value="#user.username")
* public void test(User user){}
* <p><p>
* 3.将方法参数作为bean方法的参数并获取返回值作为限流Key,暂时只支持bean的方法是String类型
* <p>
* @Service(value="parseBean")<p>
* public class ParseBean{<p>
* public String parse(String arg){<p>
* return arg+"limitKey";<p>
* }<p>
* }<p>
*<p>
* @RateLimit(value="@parseBean.parse(username)")<p>
* public void test(String username){}
*/
String value();
/**
* 限流间隔,以秒为单位
*/
int interval()default 3;
/**
* 单位之间内的速率限制
*/
int frequency()default 20;
}
SPEL配置类
/**
* Spel表达式配置类
*
* @author: 张定辉
* @date: 2024/3/7 14:20
* @description: Spel表达式配置类
*/
@Configuration
public class SpelConfig {
@Bean
public StandardEvaluationContext evaluationContext(ApplicationContext applicationContext) {
StandardEvaluationContext context = new StandardEvaluationContext();
context.addPropertyAccessor(new BeanFactoryAccessor());
context.setBeanResolver(new BeanFactoryResolver(applicationContext));
context.setTypeLocator(new StandardTypeLocator(applicationContext.getClassLoader()));
context.setTypeConverter(new StandardTypeConverter());
return context;
}
}
AOP切面
/**
* 速率限制注解处理器
*
* @author: 张定辉
* @date: 2024/3/5 21:37
* @description: 速率限制注解处理器
*/
@Aspect
@Component
@RequiredArgsConstructor
public class RateLimitHandler {
private final ApplicationContext applicationContext;
private final SpelExpressionParser parser = new SpelExpressionParser();
private final StandardEvaluationContext context;
private final RedisTemplate<String, Object> redisTemplate;
@SneakyThrows
@Around("@within(com.ai.common.annotation.RateLimit) || @annotation(com.ai.common.annotation.RateLimit)")
public Object handler(ProceedingJoinPoint joinPoint) {
Object target = joinPoint.getTarget();
String spelValue;
int interval;
int frequency;
//如果注解是标注在类上
if (target.getClass().isAnnotationPresent(RateLimit.class)) {
Class<?> aClass = target.getClass();
RateLimit annotation = aClass.getAnnotation(RateLimit.class);
spelValue = annotation.value();
interval = annotation.interval();
frequency = annotation.frequency();
if (spelValue.startsWith("@")) {
addBeanResultToContext(context, spelValue);
}
}
//注解标注在方法上
else {
Object[] args = joinPoint.getArgs();
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod();
RateLimit rateLimit = method.getAnnotation(RateLimit.class);
interval = rateLimit.interval();
frequency = rateLimit.frequency();
spelValue = rateLimit.value();
String[] parameterNames = signature.getParameterNames();
for (int i = 0; i < args.length; i++) {
//这行代码在后续的使用bean的方法返回值作为KEY限流时有用处
context.setVariable(parameterNames[i], args[i]);
if (args[i] != null && !isPrimitive(args[i].getClass())) {
addObjectPropertiesToContext(context, parameterNames[i], args[i]);
}
}
if (spelValue.startsWith("@")) {
spelValue = addBeanResultToContext(context, spelValue);
}
}
Expression expression = parser.parseExpression(spelValue);
Object key = expression.getValue(context);
//使用Redis进行限流
redisRateLimit(JSON.toJSONString(key), interval, frequency);
return joinPoint.proceed();
}
/**
* 添加对象属性值到SPEL上下文环境中
*/
@SneakyThrows
private void addObjectPropertiesToContext(StandardEvaluationContext context, String paramName, Object arg) {
Class<?> clazz = arg.getClass();
Method[] methods = clazz.getMethods();
for (Method method : methods) {
String methodName = method.getName();
if (methodName.startsWith("get") && !methodName.equals("getClass")) {
String propertyName = methodName.substring(3, 4).toLowerCase() + methodName.substring(4);
Object propertyValue = method.invoke(arg);
context.setVariable(paramName + "." + propertyName, propertyValue);
}
}
}
/**
* 将Bean方法的执行结果设置到SPEL上下文环境中
*/
@SneakyThrows
private String addBeanResultToContext(StandardEvaluationContext context, String spelValue) {
Object bean = applicationContext.getBean(spelValue.substring(1, spelValue.indexOf(".")));
String methodName = spelValue.substring(spelValue.indexOf(".") + 1, spelValue.indexOf("("));
String[] methodArgs = spelValue.substring(spelValue.indexOf("(") + 1, spelValue.indexOf(")")).split(",");
Object[] methodArgsValues = new Object[methodArgs.length];
for (int i = 0; i < methodArgs.length; i++) {
methodArgsValues[i] = context.lookupVariable(methodArgs[i]);
if (Objects.isNull(methodArgsValues[i])) {
methodArgsValues[i] = methodArgs[i];
}
}
Class<?>[] argumentsTypes = getArgumentsTypes(methodArgsValues);
boolean b = Arrays.stream(argumentsTypes).allMatch(Objects::isNull);
Method beanMethod = bean.getClass().getMethod(methodName, b?new Class<?>[0]:argumentsTypes);
Object beanMethodResult = beanMethod.invoke(bean, b?null:methodArgsValues);
context.setVariable("beanMethodResult", beanMethodResult);
return "#beanMethodResult";
}
/**
* 获取参数的类型
*/
private Class<?>[] getArgumentsTypes(Object[] args) {
Class<?>[] types = new Class<?>[args.length];
for (int i = 0; i < args.length; i++) {
Class<?> aClass = args[i].getClass();
if (aClass.isAssignableFrom(String.class)) {
String arg = (String) args[i];
types[i] = StringUtils.isBlank(arg) ? null : aClass;
} else {
types[i] = aClass;
}
}
return types;
}
/**
* 判断是否是基础数据类型
*/
private boolean isPrimitive(Class<?> clazz) {
return clazz.isPrimitive() || clazz == String.class || clazz == Integer.class
|| clazz == Long.class || clazz == Double.class || clazz == Float.class
|| clazz == Boolean.class || clazz == Character.class || clazz == Short.class
|| clazz == Byte.class;
}
/**
* 结合Redis进行限流操作
*/
private void redisRateLimit(String key, int interval, int frequency) throws OperationsException {
long l = execLua(key, interval);
if (l > frequency) {
throw new OperationsException("操作过于频繁,请稍后再试!");
}
}
/**
* 使用Lua脚本执行原子性的Redis操作,
* 如果key不存在则设置value为1并且设置过期时间为5秒,
* 如果key存在则进行累加。避免多线程并发时,由于key被修改过导致设置过期时间时失败从而导致key永不失效
*
* @return 如果没有key则返回1,如果有key则返回累加后的value
*/
private long execLua(String key, int expireTime) {
String luaScript = """
if redis.call('exists', KEYS[1]) == 0 then
redis.call('set', KEYS[1], 1, 'ex', %s)
return 1
else
return redis.call('incr',KEYS[1])
end
""".formatted(expireTime);
RedisScript<Long> script = new DefaultRedisScript<>(luaScript, Long.class);
Long result = redisTemplate.execute(script, Collections.singletonList(key));
return Objects.isNull(result) ? 0 : result;
}
}
定义Bean方法解析的业务类
该业务类主要是为了满足在使用自定义注解时我们会使用某个类的方法的返回值作为限流Key,这个类自己自定义即可,这里只是做简单的演示使用
/**
* @author: 张定辉
* @date: 2024/3/7 11:48
* @description: 使用方法返回值作为限流Key的业务方法
*/
@Service(value = "parseService")
public class ParseService {
public String parse(String param){
return param+"yyds";
}
public String parse2(){
return "yyds";
}
}
实际应用
注解标注在接口方法上,使用方法参数作为限流Key
5秒内只能访问两次该接口
@GetMapping("/test") @RateLimit(value = "#id",interval = 5,frequency = 2) public Res<Object> test(@RequestParam String id){ return Res.success(); }
标注在接口方法上,使用对象的属性值作为限流Key
5秒内只能访问两次该接口
@PostMapping("/test") @RateLimit(value = "#user.username",interval = 5,frequency = 2) public Res<Object> test(@RequestBody User user){ return Res.success(); }
标注在接口方法上,使用 parseService 业务类型的 parse方法返回值作为限流Key
5秒内只能访问两次该接口
@GetMapping("/test") @RateLimit(value = "@parseService.parse(id)",interval = 5,frequency = 2) public Res<Object> test(@RequestParam String id){ return Res.success(); }
标注在接口类下,实现该接口下的所有接口方法都限流
/** * @author: 张定辉 * @date: 2024/3/7 14:14 * @description: */ @RequestMapping("/test") @RestController @RateLimit(value = "@parseService.parse2()",interval = 5,frequency = 2) public class Text2Controller { @GetMapping("/test1") public Res<Object> test1(){ return Res.success(); } @GetMapping("/test2") public Res<Object> test2(){ return Res.success(); } }
写的可能不是很完善,如果有大佬能够指正的话不甚感激