一、切面
/**
 * Aspect that enforces {@link RateLimiter} on annotated methods.
 *
 * <p>Before an annotated method runs, the limiting script is executed against Redis with the
 * cache key for that method. If the value returned by the script is missing or greater than the
 * configured {@code count}, the call is rejected with a {@code ServiceException}.
 */
@Slf4j
@Aspect
@Component
public class RateLimiterAspect {

    private RedisTemplate<Object, Object> redisTemplate;
    private RedisScript<Long> limitScript;

    @Autowired
    public void setRedisTemplate(RedisTemplate<Object, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Autowired
    public void setLimitScript(RedisScript<Long> limitScript) {
        this.limitScript = limitScript;
    }

    /**
     * Advice executed before every method annotated with {@link RateLimiter}.
     *
     * <p>Declared {@code public}, as is conventional for annotation-style advice methods.
     *
     * @param point the intercepted join point, used to derive the cache key
     * @param rateLimiter the annotation instance carrying the limit settings
     * @throws ServiceException if the script returns no value or a value greater than
     *     {@code rateLimiter.count()}
     * @throws RuntimeException if executing the script fails for any other reason; the original
     *     exception is attached as the cause
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String key = getKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(key);
        try {
            // keys are exposed to the script as KEYS[...], count and time as ARGV[1] and ARGV[2].
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (number == null || number > count) {
                throw new ServiceException("访问过于频繁,请稍后再试");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number, key);
        } catch (ServiceException e) {
            // Rate-limit rejections must reach the caller unchanged.
            throw e;
        } catch (Exception e) {
            // Keep the root cause so the underlying failure can still be diagnosed.
            throw new RuntimeException("服务器限流异常,请稍后再试~", e);
        }
    }

    /**
     * Builds the cache key for the intercepted method.
     *
     * <p>The key is the annotation's {@code key()} prefix followed by the declaring class name,
     * a {@code "-"} and the method name.
     *
     * @param rateLimiter the annotation instance supplying the key prefix and limit type
     * @param joinPoint the intercepted join point; its signature must be a {@link MethodSignature}
     * @return the cache key for the intercepted method
     */
    public String getKey(RateLimiter rateLimiter, JoinPoint joinPoint) {
        // The builder never leaves this method, so the synchronized StringBuffer is unnecessary.
        StringBuilder keyBuilder = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            // TODO: resolve the client's real IP address and include it in the key
        }
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        keyBuilder.append(targetClass.getName()).append("-").append(method.getName());
        String key = keyBuilder.toString();
        log.info("可以看一下:{}", key);
        return key;
    }
}
二、注解
/**
 * Marks a method whose invocations are rate limited.
 *
 * <p>Retained at runtime and applicable to methods only.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /** Prefix of the cache key under which the request counter is stored. */
    String key() default "rate_limit:";

    /**
     * Length of the counting window. The limiting script passes this value to the Redis
     * {@code expire} command, so it is presumably in seconds.
     */
    int time() default 60;

    /** Maximum number of requests allowed within one window. */
    int count() default 100;

    /** Strategy used to decide what the limit is keyed on. */
    LimitType limitType() default LimitType.DEFAULT;
}
/**
 * Strategies for choosing what a rate limit is keyed on.
 */
public enum LimitType {

    /** Default strategy. */
    DEFAULT,

    /** Limit by client IP (NOTE: the aspect's IP handling is still a TODO). */
    IP;
}
三、关键脚本解读
-- Rate-limit counter script: increments the counter stored at KEYS[1] and returns its value.
-- Once the stored value already exceeds ARGV[1], it is returned without incrementing.
-- Lua has three kinds of variables: globals, locals (declared with `local`), and table fields.
-- KEYS is the array of key arguments (note: indices start at 1).
local key = KEYS[1]
-- ARGV is the array of additional (variadic) arguments.
-- tonumber converts an argument to a number.
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key);
if current and tonumber(current) > count then
return tonumber(current);
end
-- Otherwise count this request.
current = redis.call('incr', key)
-- Set the key's expiry only when the counter has just become 1.
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current);
方法参数说明:第一个参数是脚本;keys 列表中的元素依次对应 KEYS[1]、KEYS[2]…;可变参数 args 依次对应 ARGV[1]、ARGV[2]…(本例中 count 对应 ARGV[1],time 对应 ARGV[2])。
Redis配置类
/**
 * Executes the given script by delegating to the script executor.
 *
 * @param script the script to execute
 * @param keys the keys handed to the script
 * @param args the additional arguments handed to the script
 * @return whatever the script executor returns for this script
 */
@Override
public <T> T execute(RedisScript<T> script, List<K> keys, Object... args) {
    final T scriptResult = scriptExecutor.execute(script, keys, args);
    return scriptResult;
}
/**
 * Redis configuration: declares the {@code RedisTemplate} and the rate-limiting script beans.
 */
@Configuration
@EnableCaching
public class RedisConfig extends CachingConfigurerSupport {

    /**
     * Creates the {@code RedisTemplate} bean.
     *
     * <p>Keys and hash keys use {@code StringRedisSerializer}; values and hash values use
     * {@code FastJson2JsonRedisSerializer}.
     *
     * @param connectionFactory the connection factory the template operates on
     * @return the initialized template
     */
    @Bean
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        StringRedisSerializer keySerializer = new StringRedisSerializer();
        FastJson2JsonRedisSerializer valueSerializer = new FastJson2JsonRedisSerializer(Object.class);

        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(connectionFactory);
        redisTemplate.setKeySerializer(keySerializer);
        redisTemplate.setHashKeySerializer(keySerializer);
        redisTemplate.setValueSerializer(valueSerializer);
        redisTemplate.setHashValueSerializer(valueSerializer);
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * Creates the rate-limiting script bean, whose result is read as a {@code Long}.
     *
     * @return the script definition
     */
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptText(limitScriptText());
        return script;
    }

    /**
     * Returns the Lua source of the rate-limiting script.
     *
     * <p>The script increments the counter at {@code KEYS[1]} and returns it. When the stored
     * value already exceeds {@code ARGV[1]} it is returned without incrementing. The expiry
     * {@code ARGV[2]} is applied when the counter has just become 1.
     */
    private String limitScriptText() {
        return String.join("\n",
                "local key = KEYS[1]",
                "local count = tonumber(ARGV[1])",
                "local time = tonumber(ARGV[2])",
                "local current = redis.call('get', key);",
                "if current and tonumber(current) > count then",
                " return tonumber(current);",
                "end",
                "current = redis.call('incr', key)",
                "if tonumber(current) == 1 then",
                " redis.call('expire', key, time)",
                "end",
                "return tonumber(current);");
    }
}
四、使用
/**
 * Demo controller showing how {@link RateLimiter} is applied to an endpoint.
 */
@RestController
@RequestMapping("/test/limit")
public class TestRateLimitController {

    /**
     * Sample endpoint limited to 5 requests per window of length 1.
     *
     * @param id request parameter; not used by the method body
     * @return the literal string {@code "ok"}
     * @throws Exception if the simulated delay is interrupted
     */
    @RequestMapping("/get")
    @RateLimiter(count = 5, time = 1)
    public String getUserInfo(String id) throws Exception {
        // Pause briefly to simulate work before answering.
        final long simulatedDelayMillis = 200;
        Thread.sleep(simulatedDelayMillis);
        System.out.println("/test/limit/get");
        return "ok";
    }
}
总结:
利用切面编程(AOP)和 Redis Lua 脚本,实现了基于注解的接口限流。