阅读本文需要对基本的令牌桶算法有一定的了解。
背景
之前在服务治理的文档提到过限流,当时主要介绍了google的动态限流算法,其借鉴了传输层tcp的bbr拥塞控制算法,确实是让人眼前一亮。对于相对传统的限流算法比如漏桶、令牌桶这些就一笔带过了,因为确实比较简单,实现也不复杂。
最近在项目中涉及到了分布式限流,我们是使用了开源的lua脚本实现了利用redis的分布式令牌桶。理解了一下脚本的内容后觉得这个算法实现是很巧妙,因为记录分享一下。文末附该lua脚本。
正文
在令牌桶的实现中,通常都会维护currentTokens(当前有多少令牌)这个状态。currentTokens涉及到比较多的操作,一方面取令牌时会减少currentTokens,另一方面currentTokens会随时间均匀地增加,直至达到maxTokens(桶的上限),相对比较复杂。
例如在guava的ratelimit中需要两个变量currentTokens以及lastOpreationTime。每次取令牌时需要根据lastOpreationTime计算当前的令牌数,伪代码大概如下。
currentTokens += (time.Now()- lastOpreationTime)*tokensRecoveryPerUnitTime
currentTokens = max(currentTokens, maxTokens)
该lua脚本的巧妙之处在于将所有的状态都转换为时间,所以只需要关心时间就好。下面介绍其思路。
在time1时刻,认为此时处于初始状态,令牌桶为满的。num表示桶的大小,即maxTokens。interval表示向桶中放入一个令牌的时间间隔。
在time1我们执行get操作,取出n个令牌,get_interval = n*interval,used_interval = get_interval。此时used interval是小于bucket interval的,限流通过。
此时我们记录一个时刻time_at,time_at = now + used_interval。time_at的含义是当时间大于等于time_at时,令牌桶就是满的。
然后来到time2,经过time2-time1,我们会向桶中放入(time2-time1)/interval个令牌,used_interval会缩减time2-time1。此时的used_interval = time_at - time.Now()
time2时刻先取了x个令牌,可以看到此时是可以通过限流的,time_at更新。
time2时刻再取y个令牌,限流不通过,但是需要多久才能通过呢?
retry_after = time_at + y*interval - (now+bucket_interval)
经过上述的描述,其实只需要维护time_at这个时刻就保证所有的状态,并且在redis中可以通过redis.Call(“SET”, key, time_at, “Expire”, time_at - now )的方式维护time_at,即如果没有该key的值,就认为桶是满的。
以上。
实现来自这里
-- this script has side-effects, so it requires replicate commands mode
redis.replicate_commands()
local rate_limit_key = KEYS[1]
local burst = ARGV[1]
local rate = ARGV[2]
local period = ARGV[3]
local cost = tonumber(ARGV[4])
local emission_interval = period / rate
local increment = emission_interval * cost
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
local reset_after = tat - now
local retry_after = diff * -1
return {
0, -- allowed
0, -- remaining
tostring(retry_after),
tostring(reset_after),
}
end
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
local retry_after = -1
return {cost, remaining, tostring(retry_after), tostring(reset_after)}