思想:使用aop,在有注解@Ratelimiter(注解中的变量包括时间窗口,限制访问次数)的方法中执行一个前置通知,在执行方法前 执行redis的lua脚本去redis中查询访问次数(key值为访问方法名称,value值为访问次数,如果是第一次访问,将时间窗口时间设置为该键值对的过期时间),在时间窗口内超过访问次数,抛出自定义异常交给全局处理器中对应的方法处理,返回相应的json。
1.自定义注解
public @interface RateLimiter {
/**
* 限流的前缀
* @return
*/
String key() default "rate_limit";
/**
* 限流时间窗
*/
int time() default 60;
/**
* 在时间窗内的限流次数
* @return
*/
int count() default 100;
LimitType limiType() default LimitType.DEFALUT;
}
2.自定义切面
@Autowired
RedisTemplate<Object,Object> redisTemplate;//注入自定义的redistemplate
@Autowired
RedisScript<Long> redisScript;//注入自定义的redisScript
@Before("@annotation(rateLimiter)")
public void before(JoinPoint joinPoint, RateLimiter rateLimiter){
int count = rateLimiter.count();//最大访问次数
int time = rateLimiter.time();//时间窗口
String combineKey = getCombineKey(rateLimiter,joinPoint);//调用次数在redis中的key(前缀加接口名字(可选ip))
try {
Long number = redisTemplate.execute(redisScript, Collections.singletonList(combineKey),time,count);//执行自定义的lua脚本保证原子性
if(number == null || number.intValue() > count){
//超过限流
logger.info("当前接口已经达到最大限流次数");
throw new RateLimitException("访问过于频繁,请稍后访问");
}
logger.info("一个时间窗{}秒内请求次数是多少:{},当前请求次数:{},缓存的key为{}",time,count,number,combineKey);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public String getCombineKey(RateLimiter rateLimiter,JoinPoint joinPoint){
StringBuffer key = new StringBuffer(rateLimiter.key());
if(rateLimiter.limiType()== LimitType.IP){
//key.append(IpUtils.getIpAddress(((ServletRequestAttributes)RequestContextHolder.getRequestAttributes()).getRequest()));
}
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod();
key.append(":").append(method.getDeclaringClass().getName()).append('-').append(method.getName());
return key.toString();
}
3.自定义异常与全局异常处理
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(RateLimitException.class)
public Map<String,Object> rateLimitException(RateLimitException e){
Map<String,Object> map = new HashMap<>();
map.put("status",500);
map.put("message",e.getMessage());
return map;
}
}
public class RateLimitException extends Exception{
public RateLimitException(String messge){
super(messge);
}
}
4.自定义lua脚本
local key=KEYS[1]--将传进来的KEYS数组的第一项赋值给这个key变量
local time=tonumber(ARGV[1])
local count=tonumber(ARGV[2])
local current=redis.call('get',key)
if current and tonumber(current)>count then --大于限制访问次数直接返回
return tonumber(current)
end
current = redis.call('incr',key) --第一次访问接口设置一个限制时间
if tonumber(current)==1 then
redis.call('expire',key,time)
end
return tonumber(current); --返回正常访问次数
5.自定义redis config
@Configuration
public class RedisConfig {
//区别在于1能操作的数据类型可以是对象,2只能是字符串
//自定义redisTemplate 不用jdk的序列化方案 防止加前缀 (加前缀之后不能使用redis命令直接获取)使用json的序列化方案
// RedisTemplate redisTemplate;
// StringRedisTemplate stringRedisTemplate;
@Bean
RedisTemplate<Object,Object> redisTemplate(RedisConnectionFactory redisConnectionFactory){
RedisTemplate<Object,Object> template = new RedisTemplate<>();
template.setConnectionFactory(redisConnectionFactory);
Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
template.setHashKeySerializer(serializer);
template.setKeySerializer(serializer);
template.setHashValueSerializer(serializer);
template.setValueSerializer(serializer);
return template;
}
@Bean
DefaultRedisScript<Long> limitScript(){
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setResultType(Long.class);
script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
return script;
}
}
测试
@RestController
public class HelloController {
@RateLimiter(time = 10,count = 3)//10秒之内这个接口可以访问3次
@GetMapping("/hello")
public String hello(){
return "hello";
}
@GetMapping("/baby")
public String baby(){
return "baby";
}
}