用令牌桶的思路自定义一个CurrentLimiter

前言

最近在实习过程中,导师让我给一个在分布式环境下的接口做限流,自己上网查了资料接的使用令牌桶的思想应该可以实现,如果大家对令牌桶不懂,可以查看我之前写过的文章。
对接口进行限流的四个基本算法

思路

  • 这次使用的是redis,把桶设置在redis中保证分布式环境下一个接口只有一个令牌桶
  • 用拦截器的方式,如果能获取到令牌,则通过,如果不能,则拦截
  • 设置成注解,使得每一个不同的接口可以有不同的限流
  • 增加总的开关,使得在不想要限流的情况下可以一键关闭

Redis:作为Lock和令牌桶

  • 此次使用的是JedisPool
  • 此次用Redis有2个作用:
    • 有些并发情况下加锁
    • 作为令牌桶
package com.xiaoxiao.current_limiting.util;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import lombok.extern.slf4j.Slf4j;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.params.SetParams;

@Slf4j
@Service
public class JedisPoolService {

    @Autowired
    private JedisPool jedisPool;

    /**
     * 获取锁,失败直接返回
     */
    public boolean tryLockFailed(String lock, String value) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            //SetParams.setParams().nx() 如果不存在才设置
            //如果成功res:OK  失败:null
            String res = jedis.set(lock, value, SetParams.setParams().nx());
            return res == null ? false : true;
        } finally {
            close(jedis);
        }
    }

    /**
     * 获取锁,失败直接返回
     */
    public boolean tryLockFailed(String lock, String value, int expired) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            //.ex(expired) 增加一个过期时间,防止忘记del而导致死锁。
            String res = jedis.set(lock, value, SetParams.setParams().nx().ex(expired));
            return res == null ? false : true;
        } finally {
            close(jedis);
        }
    }

    /**
     * 获得锁,获取不到就阻塞
     */
    public boolean tryLock(String lock, String value) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            String res = jedis.set(lock, value, SetParams.setParams().nx());
            while (res == null) {
                res = jedis.set(lock, value, SetParams.setParams().nx());
            }
            return true;
        } finally {
            close(jedis);
        }
    }

    /**
     * 获得锁,获取不到就阻塞
     */
    public boolean tryLock(String lock, String value, int expired) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            String res = jedis.set(lock, value, SetParams.setParams().nx().ex(expired));

            while (res == null) {
                res = jedis.set(lock, value, SetParams.setParams().nx().ex(expired));
            }
            return true;
        } finally {
            close(jedis);
        }
    }

    /**
     * 释放锁
     */
    public void release(String lock) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            jedis.del(lock);
        } finally {
            close(jedis);
        }
    }

	//获取令牌
    public Long decrBucket(String Bucket) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.decrBy(Bucket, 1);
        } finally {
            close(jedis);
        }
    }

	//添加令牌
    public Long incrBucket(String bucket) {
        return incrBucket(bucket, 1L);
    }

    public Long incrBucket(String bucket, Long rate) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.incrBy(bucket, rate);
        } finally {
            close(jedis);
        }
    }

	//获取剩余的令牌数
    public Long getToken(String Bucket) {
        return Long.parseLong(get(Bucket));
    }

    public String get(String key) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            return jedis.get(key);
        } finally {
            close(jedis);
        }
    }

    public void setIfAbsent(String key, String value) {
        Jedis jedis = null;
        try {
            jedis = getJedis();
            jedis.setnx(key, value);
        } finally {
            close(jedis);
        }
    }
    //
    //    //序列化
    //    public String objectSerializable(Object obj) {
    //        String objStr;
    //        ByteArrayOutputStream byteArrayOutputStream = null;
    //        ObjectOutputStream objectOutputStream = null;
    //        try {
    //            byteArrayOutputStream = new ByteArrayOutputStream();
    //            objectOutputStream = new ObjectOutputStream(byteArrayOutputStream);
    //            objectOutputStream.writeObject(obj);
    //            objStr = byteArrayOutputStream.toString("UTF-8");
    //
    //        } catch (IOException e) {
    //            log.error(ExceptionUtils.getStackTrace(e));
    //            return null;
    //        } finally {
    //            try {
    //                if (objectOutputStream != null) {
    //                    objectOutputStream.close();
    //                }
    //                if (byteArrayOutputStream != null) {
    //                    byteArrayOutputStream.close();
    //                }
    //            } catch (Exception e) {
    //                log.error(ExceptionUtils.getStackTrace(e));
    //                return null;
    //            }
    //        }
    //
    //        return objStr;
    //    }
    //
    //    //反序列化
    //    public Object objectDeserializable(String objStr) {
    //        Object obj;
    //        ByteArrayInputStream inputStream = null;
    //        ObjectInputStream objectInputStream = null;
    //        try {
    //            inputStream = new ByteArrayInputStream(objStr.getBytes("UTF-8"));
    //            objectInputStream = new ObjectInputStream(inputStream);
    //            obj = objectInputStream.readObject();
    //        } catch (Exception e) {
    //            log.error(ExceptionUtils.getStackTrace(e));
    //            return null;
    //        } finally {
    //            try {
    //                if (objectInputStream != null) {
    //                    objectInputStream.close();
    //                }
    //                if (inputStream != null) {
    //                    inputStream.close();
    //                }
    //            } catch (Exception e) {
    //                log.error(ExceptionUtils.getStackTrace(e));
    //                return null;
    //            }
    //        }
    //        return obj;
    //    }

    private Jedis getJedis() {
        return jedisPool.getResource();
    }

    private void close(Jedis jedis) {
        if (jedis == null) {
            return;
        }
        jedis.close();
    }
}

实现CurrentLimiter的核心代码

package com.xiaoxiao.current_limiting.entity;

import java.io.Serializable;
import java.util.Date;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import com.xiaoxiao.current_limiting.util.JedisPoolService;
import com.xiaoxiao.current_limiting.util.SpringContextUtil;

import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;

@NoArgsConstructor
@Slf4j
public class RateLimiterCloud implements RateLimiter, Serializable {
    private long lastUpdateTime = new Date().getTime(); //上次令牌桶更新的时间
    private long QPS;  //QPS
    private String bucket; //"currentLimit$Bucket:Controller:MethodName" 令牌桶的名称
    private boolean fastFailed; //是否获取不到就退出
    private long period; //微秒 周期
    private long leaderExpire = 1000;//leader过期时间 1s
    private int lockExpire = 5;
    private String LOCK_PUT;
    private String LOCK_GET;
    private static JedisPoolService jedisPoolService = SpringContextUtil.getBean(JedisPoolService.class); //因为本次的RateLimiterCloud是采用new的形式,所以不能靠@Autowared注入,不让会返回null
    private String AppCode =
            SpringContextUtil.getApplicationName() + SpringContextUtil.getApplicationPort() + this.hashCode();
    //实例的唯一标示 因为是分布式,所以一个接口会有多个RateLimiterCloud实例

    public void updateParams(long QPS) { //每次重启可能QPS会改变,所以这里做了更新,在每次访问的时候调用
        this.QPS = QPS;
        this.period = 1000 * 1000 / QPS;
    }

    public RateLimiterCloud(String bucket, long QPS, boolean fastFailed, String lockMethod) {
        this.QPS = QPS;
        this.period = 1000 * 1000 / QPS;
        this.fastFailed = fastFailed;
        this.bucket = bucket;
        this.LOCK_PUT = lockMethod + "$PUT";
        this.LOCK_GET = lockMethod + "$GET";
        log.info("CreateRateLimiterCloud--->OPS:{},fastFailed:{},bucket:{},period:{}", QPS, fastFailed, bucket, period);
        initBuck();
        updateTokenSchedule(); //创建实例的时候也开启定时任务
    }

    //初始化Buck
    private void initBuck() {
        jedisPoolService.setIfAbsent(bucket, "0");
    }

    /**
     * 定时更新令牌数
     * 分布式,选取一个作为leader来开启定时任务,其他的实例负责监听,如果在一段时间监听到没有添加令牌,就开启抢占leader位置
     */
    private void updateTokenSchedule() {
        ScheduledExecutorService service = Executors.newScheduledThreadPool(1);
        log.info(bucket + ":updateTokenSchedule start...");
        service.scheduleAtFixedRate(() -> {
            if (jedisPoolService.tryLockFailed(LOCK_PUT, AppCode) || AppCode
                    .equalsIgnoreCase(jedisPoolService.get(LOCK_PUT))) { //更新
                Long token = jedisPoolService.getToken(bucket);
                if (QPS > token) {
                    jedisPoolService.incrBucket(bucket);
                }
            } else {
                if (System.currentTimeMillis() - lastUpdateTime > leaderExpire) {
                    jedisPoolService.release(LOCK_PUT);
                }
            }
        }, 0, period, TimeUnit.MICROSECONDS);

    }

    //获取令牌
    @Override
    public boolean tryAcquire() {
        if (fastFailed) {
            return tryAcquireFailed();
        } else {
            //TODO 这里还没有写 本次只写了 快速失败
            return true;

        }
    }

    //获取令牌
    @Override
    public boolean tryAcquireFailed() {
        try {
            jedisPoolService.tryLock(LOCK_GET, LOCK_GET, lockExpire); //lock
            if (jedisPoolService.getToken(bucket) > 0) { //令牌数大于0
                jedisPoolService.decrBucket(bucket);
                return true;
            } else {
                return false;
            }
        } finally {
            jedisPoolService.release(LOCK_GET);
        }

    
}

创建CurrentLimiter注解

package com.xiaoxiao.current_limiting.entity;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
@Inherited
public @interface CurrentLimiter {

    long QPS() default 10000;

    boolean fastFailed() default true;

    boolean init() default false;
}

创建拦截器

package com.xiaoxiao.current_limiting.interception;

import java.util.HashMap;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;

import com.xiaoxiao.current_limiting.entity.CurrentLimiter;
import com.xiaoxiao.current_limiting.entity.RateLimiterCloud;
import com.xiaoxiao.current_limiting.util.JedisPoolService;

import lombok.extern.slf4j.Slf4j;

@Slf4j
@Component
public class CurrentLimitInterception implements HandlerInterceptor {

    @Autowired
    private JedisPoolService jedisPoolService;

    private Map<String, RateLimiterCloud> rateLimiterCloudMap = new HashMap<>();

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }

        HandlerMethod handlerMethod = (HandlerMethod) handler;
        CurrentLimiter currentLimiter = handlerMethod.getMethodAnnotation(CurrentLimiter.class);

        //如果没有注解 就不限流
        if (currentLimiter == null) {
            log.info("没有注解,不限流");
            return true;
        } else { //限流
            long QPS = currentLimiter.QPS();
            boolean fastFailed = currentLimiter.fastFailed();
            //"className:methodName"
            String lock = handlerMethod.getBeanType().getName() + ":" + handlerMethod.getMethod().getName();
            RateLimiterCloud res;
            //查看这个方法已经是否有令牌桶,如果没有,初始化
            res = rateLimiterCloudMap.get("CurrentLimiter$Need:" + lock);
            if (ObjectUtils.isEmpty(res)) {
                boolean getLock = jedisPoolService.tryLockFailed(lock, lock, 5);
                if (getLock) {
                    res = new RateLimiterCloud(lock + "$Bucket", QPS, fastFailed, lock);
                    rateLimiterCloudMap.put("CurrentLimiter$Need:" + lock, res);
                    jedisPoolService.release(lock);
                } else {
                    return false;
                }
            } else {
                res.updateParams(QPS);

            }
            return res.tryAcquire();

        }
    }
}

注册拦截器

package com.xiaoxiao.current_limiting.interception;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Component
@ConditionalOnProperty(prefix = "current-limiting",name = "enabled",havingValue = "true") //添加一键开启功能,如果在配置文件中current-limiting.enabled=true这个类才会起效
public class CurrentLimitInterceptorConfigure implements WebMvcConfigurer {
    @Autowired
    private CurrentLimitInterception currentLimitInterception;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(currentLimitInterception);
    }
}

验证

  • 创建一个Controller
package com.xiaoxiao.current_limiting.controller;

import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import com.xiaoxiao.current_limiting.entity.CurrentLimiter;

import lombok.extern.slf4j.Slf4j;

@Slf4j
@RestController
public class CurrentLimiterDemoController {

    @CurrentLimiter(QPS = 10)
    @RequestMapping(value = "/test")
    public void testCurrentLimit() {
        System.out.println(System.currentTimeMillis());

    }
}

用jmeter测试,控制台打印出结果

...
1579507181 260
1579507181 372
1579507181 444
1579507181 550
1579507181 649
1579507181 755
1579507111 849
....

可以看到差不多0.1s通过一个请求

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 我可以给你一个参考: #include <stdio.h> int main() { int tokens; // 当前令牌数 int rate; // 每秒添加令牌的速率 int capacity; // 令牌桶容量 int time; // 请求令牌的时间 // 初始化令牌桶 tokens = 0; rate = 1; capacity = 10; // 请求令牌 time = 5; if (time * rate + tokens <= capacity) { tokens += time * rate; } else { tokens = ### 回答2: 令牌桶算法是一种流量控制算法,用于限制发送请求或数据的速率。下面是一个用C语言实现令牌桶算法的简单示例: ```c #include <stdio.h> #include <stdlib.h> #include <time.h> #define TOKEN_RATE 10 // 令牌产生速率,每秒产生10个令牌 #define TOKEN_BUCKET_SIZE 20 // 令牌桶容量,最多存放20个令牌 int main() { int token_bucket = TOKEN_BUCKET_SIZE; // 初始化令牌桶 time_t last_token_time = time(NULL); // 上次取出令牌的时间 while(1) { // 获取当前时间 time_t current_time = time(NULL); // 计算两次获取令牌的时间差,即已经过去的时间 double elapsed_time = difftime(current_time, last_token_time); // 根据令牌产生速率,更新令牌桶中的令牌数量 int tokens_to_add = TOKEN_RATE * elapsed_time; token_bucket = token_bucket + tokens_to_add; // 限制令牌桶中的令牌数量不超过最大容量 if(token_bucket > TOKEN_BUCKET_SIZE) { token_bucket = TOKEN_BUCKET_SIZE; } // 读取一个请求 printf("请输入一个请求(按'q'退出):"); char input = getchar(); // 判断请求是否合法 if(input == 'q') { break; // 输入q表示退出程序 } // 判断令牌桶中是否有足够的令牌 if(token_bucket > 0) { token_bucket--; printf("处理请求:%c\n", input); } else { printf("令牌桶为空,请求被丢弃!\n"); } // 更新上次取出令牌的时间 last_token_time = current_time; fflush(stdin); } return 0; } ``` 在上述示例中,令牌桶的容量(`TOKEN_BUCKET_SIZE`)被设置为20个令牌,令牌产生速率(`TOKEN_RATE`)为每秒10个令牌。程序会在每次循环中检查令牌桶中是否有足够的令牌来处理请求,如果有,则取出一个令牌并处理请求;否则,丢弃请求。程序会持续运行直到用户输入`q`退出。 注意:这只是令牌桶算法的一个简单示例,实际使用中可能需要根据具体需求进行调整。 ### 回答3: 令牌桶算法(Token Bucket Algorithm)是一种用于流量控制的算法,通过限制请求或数据包的速率来平滑服务和资源的访问。以下是一个用C语言实现令牌桶算法的示例: ```c #include <stdio.h> #include <time.h> #define TOKEN_RATE 10 // 每秒放入的令牌数量 #define TOKEN_CAPACITY 20 // 令牌桶容量 int main() { int tokenCount = 0; time_t prevTime = time(NULL); while (1) { time_t currentTime = time(NULL); double elapsedSec = difftime(currentTime, prevTime); tokenCount += (TOKEN_RATE * elapsedSec); // 添加令牌数量 if (tokenCount > TOKEN_CAPACITY) { tokenCount = TOKEN_CAPACITY; // 限制令牌数量不能超过桶的容量 } if (tokenCount > 0) { // 处理请求或传输数据 printf("通过令牌桶,处理请求或传输数据\n"); tokenCount--; } else { // 令牌桶为空,拒绝请求或暂停传输 printf("令牌桶为空,拒绝请求或暂停传输\n"); } prevTime = currentTime; } return 0; } ``` 上述代码中,`TOKEN_RATE`定义了每秒放入的令牌数量,`TOKEN_CAPACITY`定义了令牌桶的容量。算法的关键在于通过计算时间间隔,根据每秒放入的令牌数量来添加令牌数量,然后根据令牌数量的多少来判断是否可以处理请求或传输数据。如果令牌大于0,则可以处理请求或传输数据,并将令牌数量减1;如果令牌为0,则令牌桶为空,需要拒绝请求或暂停传输。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值