@Resource
private RedisTemplate<String, String> redisTemplate;
public static final String RATE_LIMIT_KEY = "rateLimit:";
@PostMapping("/test")
public RestResponse test(){
if(checkLimit()){
return RestResponse.ok();
}else {
return RestResponse.error("接口限流");
}
}
private boolean checkLimit(){
//这里可以自定义 某个接口路径,加ip等等,可以改造成一个注解
String key = "openapi";
//限流单位时间(单位为s)
int time = 10;
//单位时间内限制的访问次数
int count = 1;
//拼接 redis中的key
StringBuilder sb = new StringBuilder();
sb.append(RATE_LIMIT_KEY).append(key).append(":");
List<String> keys = Collections.singletonList(sb.toString());
//执行lua脚本
String script = luaString();
RedisScript<Long> redisScript = new DefaultRedisScript(script, Long.class);
RedisSerializer<String> stringSerializer = new StringRedisSerializer();
redisTemplate.setKeySerializer(stringSerializer);
redisTemplate.setValueSerializer(stringSerializer);
Long result = redisTemplate.execute(redisScript, keys, time+"", count+"");
if (result != null && -1 == result) {
log.info("当前接口调用超过时间段内限流,key:{}", key);
return false;
} else {
log.info("当前访问时间段内剩余{}次访问次数", result);
}
return true;
}
private String luaString(){
String lua = "redis.replicate_commands();\n" +
"local key = KEYS[1]\n" +
"local update_len = tonumber(ARGV[1])\n" +
"local key_time = 'ratetokenprefix'..key\n" +
"local curr_time_arr = redis.call('TIME')\n" +
"local nowTime = tonumber(curr_time_arr[1])\n" +
"local curr_key_time = tonumber(redis.call('get',key_time) or 0)\n" +
"local token_count = tonumber(redis.call('get',KEYS[1]) or -1)\n" +
"local token_size = tonumber(ARGV[2])\n" +
"if token_count < 0 then\n" +
"\tredis.call('set',key_time,nowTime)\n" +
"\tredis.call('set',key,token_size - 1)\n" +
"\treturn token_size -1\n" +
"else\n" +
"\tif token_count > 0 then\n" +
"\t\tredis.call('set',key,token_count - 1)\n" +
"\t\treturn token_count -1\n" +
"\telse\n" +
"\t\tif curr_key_time + update_len < nowTime then\n" +
"\t\t\tredis.call('set',key_time,nowTime)\n" +
"\t\t\tredis.call('set',key,token_size -1)\n" +
"\t\t\treturn token_size - 1\n" +
"\t\telse\n" +
"\t\t\treturn -1\n" +
"\t\tend\n" +
"\tend\n" +
"end ";
return lua;
}
lua注释
redis.replicate_commands();
-- 参数中传递的key
local key = KEYS[1]
-- 令牌桶填充 最小时间间隔
local update_len = tonumber(ARGV[1])
-- 记录 当前key上次更新令牌桶的时间的 key
local key_time = 'ratetokenprefix'..key
-- 获取当前时间(这里的curr_time_arr 中第一个是 秒数,第二个是 秒数后毫秒数),由于我是按秒计算的,这里只要curr_time_arr[1](注意:redis数组下标是从1开始的)
--如果需要获得毫秒数 则为 tonumber(arr[1]*1000 + arr[2])
local curr_time_arr = redis.call('TIME')
-- 当前时间秒数
local nowTime = tonumber(curr_time_arr[1])
-- 从redis中获取当前key 对应的上次更新令牌桶的key 对应的更新时间
local curr_key_time = tonumber(redis.call('get',key_time) or 0)
-- 获取当前key对应令牌桶中的令牌数
local token_count = tonumber(redis.call('get',KEYS[1]) or -1)
-- 当前令牌桶的容量
local token_size = tonumber(ARGV[2])
-- 令牌桶数量小于0 说明令牌桶没有初始化
if token_count < 0 then
redis.call('set',key_time,nowTime) --初始化令牌桶数量时的时间
redis.call('set',key,token_size -1)
return token_size -1
else
if token_count > 0 then --当前令牌桶中令牌数够用
redis.call('set',key,token_count - 1)
return token_count -1 --返回剩余令牌数
else --当前令牌桶中令牌数已清空
if curr_key_time + update_len < nowTime then --判断一下,当前时间秒数 与上次更新时间秒数 的间隔,是否大于规定时间间隔数 (update_len)
redis.call('set',key_time,nowTime) --重置令牌桶数量时的时间
redis.call('set',key,token_size -1) --重置令牌桶数量
return token_size - 1
else
return -1
end
end
end