最完整清晰的redis+lua脚本+令牌桶算法 实现限流

最完整清晰的redis+ lua脚本 + 令牌桶算法 实现限流控制

在网上看了好多博客,感觉不是很清楚,于是决定自己手撸一个。

一、自定义一个注解,用来给限流的方法标注
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {
    //限流唯一标示
    String key() default "";

    //限流单位时间(单位为s)
    int time() default 1;

    //单位时间内限制的访问次数
    int count();

    //是否限制ip
    boolean ipLimit() default false;
}
二、编写lua脚本

重要的地方注释得非常详细了,这里就不多解释;

主要功能是:

根据key(参数) 查询 对应的 value(令牌数)
	如果为null 说明该key 是第一次进入 
	{
		初始化 令牌桶(参数)数量;记录初始化时间 ->返回 剩余令牌数
	} 
	
	如果不为null
	{
		判断 value 是否大于1 
		{
			大于1  ->value - 1  -> 返回 剩余令牌数
			小于1  -> 判断  补充令牌时间间隔是否足够
			{
				足够 -> 补充令牌;更新补充令牌时间-> 返回 剩余令牌数
				不足够	-> 返回 -1 (说明超过限流访问次数)
			}
		}
	}
	
redis.replicate_commands();
-- 参数中传递的key
local key = KEYS[1]
-- 令牌桶填充 最小时间间隔
local update_len = tonumber(ARGV[1])
-- 记录 当前key上次更新令牌桶的时间的 key
local key_time = 'ratetokenprefix'..key
-- 获取当前时间(这里的curr_time_arr 中第一个是 秒数,第二个是 秒数后毫秒数),由于我是按秒计算的,这里只要curr_time_arr[1](注意:redis数组下标是从1开始的)
--如果需要获得毫秒数 则为 tonumber(arr[1]*1000 + arr[2])
local curr_time_arr = redis.call('TIME')
-- 当前时间秒数
local nowTime = tonumber(curr_time_arr[1])
-- 从redis中获取当前key 对应的上次更新令牌桶的key 对应的value
local curr_key_time = tonumber(redis.call('get',KEYS[1]) or 0)
-- 获取当前key对应令牌桶中的令牌数
local token_count = tonumber(redis.call('get',KEYS[1]) or -1)
-- 当前令牌桶的容量
local token_size = tonumber(ARGV[2])
-- 令牌桶数量小于0 说明令牌桶没有初始化
if token_count < 0 then
	redis.call('set',key_time,nowTime)
	redis.call('set',key,token_size -1)
	return token_size -1
else
	if token_count > 0 then --当前令牌桶中令牌数够用
		redis.call('set',key,token_count - 1)
		return token_count -1   --返回剩余令牌数
	else    --当前令牌桶中令牌数已清空
		if curr_key_time + update_len < nowTime then    --判断一下,当前时间秒数 与上次更新时间秒数  的间隔,是否大于规定时间间隔数 (update_len)
			redis.call('set',key,token_size -1)
			return token_size - 1
		else
			return -1
		end
	end
end
三、读取lua脚本
@Component
public class CommonConfig {
    /**
     * 读取限流脚本
     */
    @Bean
    public DefaultRedisScript<Number> redisluaScript() {
        DefaultRedisScript<Number> redisScript = new DefaultRedisScript<>();
        //这里脚本的路径为path for source root 路径
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("myLua.lua"))); 
        redisScript.setResultType(Number.class);
        return redisScript;
    }
    /**
     * RedisTemplate
     */
    @Bean
    public RedisTemplate<String, Serializable> limitRedisTemplate(LettuceConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Serializable> template = new RedisTemplate<String, Serializable>();
        template.setKeySerializer(new StringRedisSerializer());
        template.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        template.setConnectionFactory(redisConnectionFactory);
        return template;
    }
}
四、创建拦截器拦截带有该注解的方法
@Component
public class RateLimitInterceptor implements HandlerInterceptor {
    private final Logger LOG = LoggerFactory.getLogger(this.getClass());
    
    @Autowired
    private RedisTemplate<String, Serializable> limitRedisTemplate;

    @Autowired
    private DefaultRedisScript<Number> redisLuaScript;
    
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        assert handler instanceof HandlerMethod;
        HandlerMethod method = (HandlerMethod) handler;
        RateLimit rateLimit = method.getMethodAnnotation(RateLimit.class);
        //当前方法上有我们自定义的注解
        if (rateLimit != null) {
            //获得单位时间内限制的访问次数
            int count = rateLimit.count();
            String key = rateLimit.key();
            //获得限流单位时间(单位为s)
            int time = rateLimit.time();
            boolean ipLimit = rateLimit.ipLimit();
            //拼接 redis中的key
            StringBuilder sb = new StringBuilder();
            sb.append(Constants.RATE_LIMIT_KEY).append(key).append(":");
            //如果需要限制ip的话
            if(ipLimit){
                sb.append(getIpAddress(request)).append(":");
            }
            List<String> keys = Collections.singletonList(sb.toString());
           //执行lua脚本
            Number execute = limitRedisTemplate.execute(redisLuaScript, keys, time, count);
            assert execute != null;
            if (-1 == execute.intValue()) {
                ResultModel resultModel = ResultModel.error_900("接口调用超过限流次数");
                response.setStatus(901);
                response.setCharacterEncoding("utf-8");
                response.setContentType("application/json");
                response.getWriter().write(JSONObject.toJSONString(resultModel));
                response.getWriter().flush();
                response.getWriter().close();
                LOG.info("当前接口调用超过时间段内限流,key:{}", sb.toString());
                return false;
            } else {
                LOG.info("当前访问时间段内剩余{}次访问次数", execute.toString());
            }
        }
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {

    }
    
    public static String getIpAddr(HttpServletRequest request) {
        String ipAddress = null;
        try {
            ipAddress = request.getHeader("x-forwarded-for");
            if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
                ipAddress = request.getRemoteAddr();
            }
            // 对于通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照','分割
            // "***.***.***.***".length()
            if (ipAddress != null && ipAddress.length() > 15) { 
                // = 15
                if (ipAddress.indexOf(",") > 0) {
                    ipAddress = ipAddress.substring(0, ipAddress.indexOf(","));
                }
            }
        } catch (Exception e) {
            ipAddress = "";
        }
        return ipAddress;
    }

}
一个自定义的常量

用作redis前缀

public class Constants {
    public static final String RATE_LIMIT_KEY = "rateLimit:";
}
五、在WebConfig中注册这个这个拦截器
@Configuration
@EnableWebMvc
public class WebConfig extends WebMvcConfigurerAdapter {

    @Autowired
    private RateLimitInterceptor rateLimitInterceptor;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(rateLimitInterceptor);
        super.addInterceptors(registry);
    }
}
六、注解使用
@RestController
@RequestMapping(value = "/test")
public class TestController {

    //限流规则为 1秒内只允许同一个ip发送5次请求
    @RateLimit(key = "testGet",time = 1,count = 5,ipLimit = true)
    @RequestMapping(value = "/get")
    public ResultModel testGet(){
        return ResultModel.ok_200();
    }

}

如果觉得有问题,欢迎各位大佬指正
觉得可以的话点个赞再走吧!!!!!!

  • 6
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
实现分布式限流可以使用 RedisLua 脚本来完成。以下是可能的实现方案: 1. 使用 Redis 的 SETNX 命令来实现基于令牌算法限流 令牌算法是一种常见的限流算法,它可以通过令牌的放置和消耗来控制流量。在 Redis 中,我们可以使用 SETNX 命令来实现令牌算法。 具体实现步骤如下: - 在 Redis 中创建一个有序集合,用于存储令牌桶的令牌数量和时间戳。 - 每当一个请求到达时,我们首先获取当前令牌桶中的令牌数量和时间戳。 - 如果当前时间戳与最后一次请求的时间戳之差大于等于令牌桶中每个令牌的发放时间间隔,则将当前时间戳更新为最后一次请求的时间戳,并且将令牌桶中的令牌数量增加相应的数量,同时不超过最大容量。 - 如果当前令牌桶中的令牌数量大于等于请求需要的令牌数量,则返回 true 表示通过限流,将令牌桶中的令牌数量减去请求需要的令牌数量。 - 如果令牌桶中的令牌数量不足,则返回 false 表示未通过限流。 下面是使用 RedisLua 脚本实现令牌算法的示例代码: ```lua -- 限流的 key local key = KEYS[1] -- 令牌桶的容量 local capacity = tonumber(ARGV[1]) -- 令牌的发放速率 local rate = tonumber(ARGV[2]) -- 请求需要的令牌数量 local tokens = tonumber(ARGV[3]) -- 当前时间戳 local now = redis.call('TIME')[1] -- 获取当前令牌桶中的令牌数量和时间戳 local bucket = redis.call('ZREVRANGEBYSCORE', key, now, 0, 'WITHSCORES', 'LIMIT', 0, 1) -- 如果令牌桶为空,则初始化令牌桶 if not bucket[1] then redis.call('ZADD', key, now, capacity - tokens) return 1 end -- 计算当前令牌桶中的令牌数量和时间戳 local last = tonumber(bucket[2]) local tokensInBucket = tonumber(bucket[1]) -- 计算时间间隔和新的令牌数量 local timePassed = now - last local newTokens = math.floor(timePassed * rate) -- 更新令牌桶 if newTokens > 0 then tokensInBucket = math.min(tokensInBucket + newTokens, capacity) redis.call('ZADD', key, now, tokensInBucket) end -- 检查令牌数量是否足够 if tokensInBucket >= tokens then redis.call('ZREM', key, bucket[1]) return 1 else return 0 end ``` 2. 使用 RedisLua 脚本实现基于漏桶算法限流 漏桶算法是另一种常见的限流算法,它可以通过漏桶的容量和漏水速度来控制流量。在 Redis 中,我们可以使用 Lua 脚本实现漏桶算法。 具体实现步骤如下: - 在 Redis 中创建一个键值对,用于存储漏桶的容量和最后一次请求的时间戳。 - 每当一个请求到达时,我们首先获取当前漏桶的容量和最后一次请求的时间戳。 - 计算漏水速度和漏水的数量,将漏桶中的容量减去漏水的数量。 - 如果漏桶中的容量大于等于请求需要的容量,则返回 true 表示通过限流,将漏桶中的容量减去请求需要的容量。 - 如果漏桶中的容量不足,则返回 false 表示未通过限流。 下面是使用 RedisLua 脚本实现漏桶算法的示例代码: ```lua -- 限流的 key local key = KEYS[1] -- 漏桶的容量 local capacity = tonumber(ARGV[1]) -- 漏水速度 local rate = tonumber(ARGV[2]) -- 请求需要的容量 local size = tonumber(ARGV[3]) -- 当前时间戳 local now = redis.call('TIME')[1] -- 获取漏桶中的容量和最后一次请求的时间戳 local bucket = redis.call('HMGET', key, 'capacity', 'last') -- 如果漏桶为空,则初始化漏桶 if not bucket[1] then redis.call('HMSET', key, 'capacity', capacity, 'last', now) return 1 end -- 计算漏水的数量和漏桶中的容量 local last = tonumber(bucket[2]) local capacityInBucket = tonumber(bucket[1]) local leak = math.floor((now - last) * rate) -- 更新漏桶 capacityInBucket = math.min(capacity, capacityInBucket + leak) redis.call('HSET', key, 'capacity', capacityInBucket) redis.call('HSET', key, 'last', now) -- 检查容量是否足够 if capacityInBucket >= size then return 1 else return 0 end ``` 以上是使用 RedisLua 脚本实现分布式限流的两种方案,可以根据实际需求选择适合的方案。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值