redis + lua 分布式限流（令牌桶算法）

package com.jd.car.parrot.web.plugin.ratelimter;

import com.google.common.collect.Lists;
import com.car.parrot.web.plugin.config.Singleton;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * See https://stripe.com/blog/rate-limiters and
 * https://gist.github.com/ptarjan/e38f45f2dfe601419ca3af937fff574d#file-1-check_request_rate_limiter-rb-L11-L34
 * See  https://github.com/spring-cloud/spring-cloud-gateway/blob/master/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiter.java
 * RedisRateLimiter.
 *
 * @author zero
 */
@Slf4j
public class RedisRateLimiter {

    /**
     * logger.
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(RedisRateLimiter.class);

    private RedisScript<List<Long>> script;

    private AtomicBoolean initialized = new AtomicBoolean(false);

    /**
     * Instantiates a new Redis rate limiter.
     */
    public RedisRateLimiter() {
        this.script = redisScript();
        initialized.compareAndSet(false, true);
    }

    /**
     * This uses a basic token bucket algorithm and relies on the fact that Redis scripts
     * execute atomically. No other operations can run between fetching the count and
     * writing the new count.
     *
     * @param id            is rule id
     * @param replenishRate replenishRate
     * @param burstCapacity burstCapacity
     * @return {@code Mono<Response>} to indicate when request processing is complete
     */
    @SuppressWarnings("unchecked")
    public Mono<RateLimiterResponse> isAllowed(final String id, final double replenishRate, final double burstCapacity) {
        if (!this.initialized.get()) {
            throw new IllegalStateException("RedisRateLimiter is not initialized");
        }
        try {
            long startTime = System.currentTimeMillis();
            List<String> keys = getKeys(id);
            List<String> args = getArgs(replenishRate, burstCapacity);
            Flux<List<Long>> resultFlux = Singleton
                    .INST
                    .get(ReactiveRedisTemplate.class)
                    .execute(this.script, keys, args);

            RateLimiterResponse limiterResponse = resultFlux
                    .onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L)))
                    .reduce(new ArrayList<Long>(), (longs, l) -> {
                        longs.addAll(l);
                        return longs;
                    }).map(results -> {
                        boolean allowed = results.get(0) == 1L;
                        Long tokensLeft = results.get(1);
                        return new RateLimiterResponse(allowed, tokensLeft);
                    }).block();
            log.debug("limiterResponse:{} cost:{}", limiterResponse, (System.currentTimeMillis() - startTime));
            return Mono.just(limiterResponse);
        } catch (Exception e) {
            LOGGER.error("Error determining if user allowed from redis:", e);
        }
        return Mono.just(new RateLimiterResponse(true, -1));
    }

    private static List<String> getKeys(final String id) {
        String prefix = "request_rate_limiter.{" + id;
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        log.debug("tokenKey:{},timestampKey:{}", tokenKey, timestampKey);
        return Arrays.asList(tokenKey, timestampKey);
    }

    @SuppressWarnings("unchecked")
    private RedisScript<List<Long>> redisScript() {
        DefaultRedisScript redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("/META-INF/scripts/request_rate_limiter.lua")));
        redisScript.setResultType(List.class);
        return redisScript;
    }

    /**
     * 获取lua脚本执行参数
     *
     * 注意:虽然当前电脑普遍时钟精度只到毫秒
     *
     * @param replenishRate 令牌桶令牌生成速率
     * @param burstCapacity 令牌桶容量
     * @return lua脚本执行参数
     */
    private List<String> getArgs(final double replenishRate, final double burstCapacity) {
        Instant now = Instant.now();
        long ninthPowerOfTen = 1000 * 1000 * 1000L;
        long nowNano = now.getEpochSecond() * ninthPowerOfTen + now.getNano();

        ArrayList<String> scriptArgs = Lists.newArrayListWithCapacity(4);

        scriptArgs.add(replenishRate + "");
        scriptArgs.add(burstCapacity + "");
        scriptArgs.add(nowNano + "");
        scriptArgs.add("1");

        return scriptArgs;
    }
}

-- Token-bucket rate limiter; Redis runs the whole script atomically.
--
-- KEYS[1]  key holding the number of tokens left in the bucket
-- KEYS[2]  key holding the time (ns) up to which refill has been credited
-- ARGV[1]  refill rate, tokens per second
-- ARGV[2]  bucket capacity (maximum burst)
-- ARGV[3]  current time in nanoseconds, supplied by the caller
-- ARGV[4]  number of tokens requested
--
-- Returns { allowed_num, new_tokens }:
--   allowed_num  1 if the request was granted, otherwise 0
--   new_tokens   tokens left in the bucket after this call

local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local NANOS_PER_SECOND = 1000000000

-- Seconds needed to fill an empty bucket.
local fill_time = capacity / rate
-- Both keys expire after twice the fill time; a missing key is treated as a full
-- bucket. SETEX rejects a non-positive expire, and floor() yields 0 whenever
-- capacity / rate < 0.5, so clamp to at least 1 second.
local ttl = math.max(1, math.floor(fill_time * 2))

-- Tokens left after the previous call (a missing key means a full bucket).
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
    last_tokens = capacity
end

-- Time up to which refill has already been credited (0 when unknown).
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
    last_refreshed = 0
end

-- Elapsed time; clamped so a caller clock behind the stored timestamp adds nothing.
local delta = math.max(0, now - last_refreshed)
-- Refill rate per nanosecond (caller clocks are usually only ms-precise).
local nano_rate = rate / NANOS_PER_SECOND
-- Whole tokens earned during the interval.
local delta_tokens = math.floor(delta * nano_rate)
local total_tokens = last_tokens + delta_tokens
-- Tokens beyond capacity overflow and are discarded.
local filled_tokens = math.floor(math.min(capacity, total_tokens))

local allowed_num = 0
local new_tokens = filled_tokens
if filled_tokens >= requested then
    new_tokens = filled_tokens - requested
    allowed_num = 1
end

-- Advance the credit timestamp only by the time that produced whole tokens, so the
-- fractional remainder carries over to the next call instead of being lost. If the
-- bucket overflowed, the surplus is discarded and crediting restarts from now.
local new_refreshed = last_refreshed
if total_tokens >= capacity then
    new_refreshed = now
elseif delta_tokens > 0 then
    new_refreshed = last_refreshed + delta_tokens * NANOS_PER_SECOND / rate
end

-- Logging on every call floods the Redis log; uncomment only while debugging.
--redis.log(redis.LOG_WARNING, "delta " .. delta .. " filled_tokens " .. filled_tokens .. " allowed_num " .. allowed_num .. " new_tokens " .. new_tokens .. " ttl " .. ttl)

-- Always rewrite both keys so their TTLs are refreshed together.
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, new_refreshed)

return { allowed_num, new_tokens }

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值