spring aop基于redis的令牌桶和漏桶限流

说明: 基于Spring AOP的切面限流,使用Redis Lua脚本保证原子性。限流对象支持全局、单个应用、用户IP,也支持通过SpEL表达式解析方法入参中的某个字段作为限流对象。

限流单位:支持每秒,每分,每时,或自然分,自然日等等时间单位的限流规则

限流类型:支持令牌桶和漏桶算法

支持异常时,返还扣减的次数

支持多个限流规则组合使用,例如“每日访问次数”+“每周或每年访问次数”等组合,并可为每条规则设置优先级以及是否在异常时回退已扣减的次数。

1.编写限流单位枚举类

package com.around.common.utils.limitrate;

import java.util.Calendar;

/**
 * Time units for rate limiting.
 *
 * <p>Units of type 0 ({@code PER_*}) are rolling windows measured from the current time.
 * Units of type 1 are natural calendar windows that end at the last millisecond of the
 * current minute / hour / day / week / month / year (e.g. for the natural month, if today is
 * January 31 the window resets on February 1).
 *
 * @description: rate-limit time unit
 * @author: moodincode
 * @create 2021/1/8
 */
public enum LimitUnitEnum {
    /** Rolling windows: an interval counted from the current time, not a calendar boundary. **/
    PER_SECOND(0,"PER_SECOND","每秒"),
    PER_MINUTES(0,"PER_MINUTES","每分"),
    PER_HOUR(0,"PER_HOUR","每时"),
    PER_DAY(0,"PER_DAY","每日"),
    PER_WEEK(0,"PER_WEEK","每周"),
    PER_MONTH(0,"PER_MONTH","每月"),
    PER_YEAR(0,"PER_YEAR","每年"),
    /** Natural windows: expire at the end of the current calendar period. ***/
    MINUTES(1,"MINUTES","自然分"),
    HOUR(1,"HOUR","自然时"),
    DAY(1,"DAY","自然日"),
    WEEK(1,"WEEK","自然周"),
    MONTH(1,"MONTH","自然月"),
    YEAR(1,"YEAR","自然年"),

    ;
    /** 0 = rolling window, 1 = natural calendar window. */
    private Integer type;
    private String code;
    private String name;

    public Integer getType() {
        return type;
    }

    public void setType(Integer type) {
        this.type = type;
    }

    public String getCode() {
        return code;
    }

    public void setCode(String code) {
        this.code = code;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    LimitUnitEnum(Integer type,String code, String name) {
        this.type = type;
        this.code = code;
        this.name = name;
    }

    /**
     * Expiry timestamp of a window that starts now.
     *
     * @return epoch millis at which the window ends
     */
    public Long getExpireTs(){
        return getExpireTs(System.currentTimeMillis());
    }

    /**
     * Expiry timestamp of a window that starts at {@code compareTime}.
     *
     * <p>Rolling units add a fixed duration (a month is 30 days, a year is 365 days).
     * Natural units return the last millisecond of the enclosing calendar period.
     *
     * @param compareTime window start, epoch millis (must not be null)
     * @return epoch millis at which the window ends
     */
    public Long getExpireTs(Long compareTime){
        switch (this){
            case PER_SECOND:
                return compareTime + 1000L;
            case PER_MINUTES:
                return compareTime + 60_000L;
            case PER_HOUR:
                return compareTime + 60L * 60_000L;
            case PER_DAY:
                return compareTime + 24L * 60 * 60_000L;
            case PER_WEEK:
                return compareTime + 7L * 24 * 60 * 60_000L;
            case PER_MONTH:
                return compareTime + 30L * 24 * 60 * 60_000L;
            case PER_YEAR:
                return compareTime + 365L * 24 * 60 * 60_000L;
            default:
                return getNaturalTime(compareTime);
        }
    }

    /**
     * Last millisecond of the natural period (in the default time zone) containing {@code compareTime}.
     *
     * @param compareTime reference time, epoch millis
     * @return end of the enclosing minute/hour/day/week/month/year, epoch millis
     */
    private Long getNaturalTime(Long compareTime) {
        Calendar c = Calendar.getInstance();
        c.setTimeInMillis(compareTime);
        // Intentional fall-through: a coarser unit first moves to its last day, then shares the
        // "end of day", "end of hour" and "end of minute" adjustments of the finer units below it.
        switch (this){
            case YEAR:
                c.set(Calendar.MONTH, Calendar.DECEMBER);
                // fall through
            case MONTH:
                c.set(Calendar.DAY_OF_MONTH, c.getActualMaximum(Calendar.DAY_OF_MONTH));
                // fall through
            case WEEK:
                // Only WEEK moves forward to Sunday (weeks end on Sunday); YEAR/MONTH just pass through.
                if (this == WEEK) {
                    int dayOfWeek = c.get(Calendar.DAY_OF_WEEK);
                    if (dayOfWeek > Calendar.SUNDAY) {
                        c.add(Calendar.DATE, 8 - dayOfWeek);
                    }
                }
                // fall through
            case DAY:
                // HOUR_OF_DAY is the 0-23 field. Calendar.HOUR is 0-11, so setting it to 23 on a
                // PM time rolled the result into 11 AM of the following day.
                c.set(Calendar.HOUR_OF_DAY, 23);
                // fall through
            case HOUR:
                c.set(Calendar.MINUTE, 59);
                // fall through
            case MINUTES:
                c.set(Calendar.SECOND, 59);
                break;
            default:
                break;
        }
        c.set(Calendar.MILLISECOND, 999);
        return c.getTimeInMillis();
    }

}

2.编写限流类型枚举类

package com.around.common.utils.limitrate;

/**
 * Rate-limiting strategy: the scope a bucket is shared across (client IP, application
 * instance, or everyone) combined with the algorithm (token bucket or leaky bucket).
 *
 * <p>{@link #getCode()} selects the algorithm passed to the provider: 0 = token bucket,
 * 1 = leaky bucket.
 *
 * @program: com-around
 * @author: moodincode
 * @create: 2021/1/08
 **/
public enum LimitTypeEnum {

    /** Token bucket, one bucket per requesting client IP. */
    IP_TOKEN(0, "IP_TOKEN", "单IP-令牌桶算法"),
    /** Token bucket, one bucket per deployed application instance. */
    APP_TOKEN(0, "APP_TOKEN", "单应用-令牌桶算法"),
    /** Token bucket shared by all callers. */
    GLOBAL_TOKEN(0, "GLOBAL_TOKEN", "全局-令牌桶算法"),
    /** Leaky bucket, one bucket per requesting client IP. */
    IP_LEAKY(1, "IP_LEAKY", "单IP-漏桶算法"),
    /** Leaky bucket, one bucket per deployed application instance. */
    APP_LEAKY(1, "APP_LEAKY", "单应用-漏桶算法"),
    /** Leaky bucket shared by all callers. */
    GLOBAL_LEAKY(1, "GLOBAL_LEAKY", "全局-漏桶算法");

    /** Algorithm discriminator: 0 = token bucket, 1 = leaky bucket. */
    private final int code;
    /** Stable identifier, equal to the constant's name. */
    private final String name;
    /** Human-readable description. */
    private final String decr;

    LimitTypeEnum(int algorithm, String label, String description) {
        code = algorithm;
        name = label;
        decr = description;
    }

    /** @return 0 for token-bucket types, 1 for leaky-bucket types */
    public int getCode() {
        return code;
    }

    /** @return the identifier string of this type */
    public String getName() {
        return name;
    }

    /** @return the human-readable description */
    public String getDecr() {
        return decr;
    }
}

3.编写限流通用接口

package com.around.common.utils.limitrate;

/**
 * Storage-backed bucket operations used by the rate limiter.
 *
 * @description: rate-limit provider
 * @author: moodincode
 * @create 2021/1/8
 */
public interface LimitRateProvider {

    /**
     * Reads the bucket's current count without changing it.
     *
     * @param key bucket key
     * @return the stored count; 0 when the bucket does not exist or its count is 0
     */
    Long getCurrentCount(String key);

    /**
     * Atomically refills/drains the bucket and then consumes {@code num} units from it.
     *
     * @param key         bucket key
     * @param num         number of units to consume
     * @param ts          time of this consumption, epoch millis
     * @param expire      time at which an idle bucket is reset, epoch millis
     * @param rate        tokens generated (token bucket) or units drained (leaky bucket) per interval
     * @param interval    interval between refills/drains, millis; 0 means "add one rate unit directly"
     * @param capacity    maximum bucket size; -1 means unlimited
     * @param type        algorithm: 0 = token bucket, 1 = leaky bucket
     * @param compareTime reference time for elapsed-interval calculation, epoch millis; 0 means "now"
     * @return the provider's result; callers in this project treat a negative value as "rejected"
     */
    Long consumeCount(String key, Long num, Long ts, Long expire, Long rate, Long interval,
                      Long capacity, int type, Long compareTime);

    /**
     * Returns units to the bucket (used to roll back a consumption).
     *
     * @param key      bucket key
     * @param num      number of units to return
     * @param capacity capacity limit: 0 means unlimited, greater than 0 caps the count
     * @param type     algorithm: 0 = token bucket, 1 = leaky bucket
     * @return provider-specific result
     */
    Long addCount(String key, long num, long capacity, int type);
}

4.编写redis序列化工具类

package com.around.common.utils.limitrate;

import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.SerializationException;
import java.nio.charset.StandardCharsets;

/**
 * Redis serializer that writes {@link Long} values as their decimal string in UTF-8, so the
 * Lua scripts can read script arguments with {@code tonumber}.
 *
 * @description: Long <-> decimal string serializer
 * @author: moodincode
 * @create 2021/1/8
 */
public class LongStringSerialize implements RedisSerializer<Long> {

    /**
     * @param number value to write; may be null
     * @return UTF-8 bytes of the decimal representation, or null for a null value
     *         (instead of the literal text "null", which Lua's tonumber cannot parse)
     */
    @Override
    public byte[] serialize(Long number) throws SerializationException {
        if (number == null) {
            return null;
        }
        return String.valueOf(number).getBytes(StandardCharsets.UTF_8);
    }

    /**
     * @param bytes UTF-8 decimal digits; may be null or empty
     * @return the parsed value, or null for null/empty input
     * @throws SerializationException if the bytes are not a valid decimal long
     */
    @Override
    public Long deserialize(byte[] bytes) throws SerializationException {
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        String text = new String(bytes, StandardCharsets.UTF_8);
        try {
            return Long.valueOf(text);
        } catch (NumberFormatException e) {
            throw new SerializationException("Cannot deserialize '" + text + "' as Long", e);
        }
    }
}

5.编写redis限流实现类和lua脚本

package com.around.common.utils.limitrate;

import com.alibaba.fastjson.support.spring.GenericFastJsonRedisSerializer;
import com.google.common.collect.Lists;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.util.List;

/**
 * Redis implementation of {@link LimitRateProvider}. Each bucket is a Redis hash with fields
 * {@code count}, {@code gen_ts} (last refill/drain time) and {@code get_ts} (last consume time).
 * All read-modify-write sequences run inside Lua scripts, so each operation executes atomically
 * on the Redis server.
 *
 * @description: Redis-backed rate-limit provider
 * @author: moodincode
 * @create 2021/1/8
 */
@Component
public class RedisLimitRateProvider implements LimitRateProvider {
    @Resource
    private RedisTemplate<String,Object> redisTemplate;

    /**
     * Token-bucket consume script.
     * ARGV[1]: num, ARGV[2]: interval, ARGV[3]: rate, ARGV[4]: capacity, ARGV[5]: compareTime,
     * ARGV[6]: ts, ARGV[7]: expire (absolute epoch millis for PEXPIREAT).
     * Returns the tokens left after consuming, or -1 when there are not enough tokens.
     */
    private static final String CONSUME_TOKEN_SCRIP = "local r_key = tostring(KEYS[1]);\n"
            + "local num = tonumber(ARGV[1]);\n"
            + "local interval = tonumber(ARGV[2]);\n"
            + "local rate = tonumber(ARGV[3]);\n"
            + "local capacity = tonumber(ARGV[4]);\n"
            + "local compareTime = tonumber(ARGV[5]);\n"
            + "local ts = tonumber(ARGV[6]);\n"
            + "local expire = tonumber(ARGV[7]);\n"
            + "if tonumber(redis.call('exists', r_key)) == 1 then\n"
            + "    local gen_ts = tonumber(redis.call('hget', r_key, 'gen_ts'));\n"
            + "    local count = tonumber(redis.call('hget', r_key, 'count'));\n"
            + "    local factor = math.ceil((compareTime - gen_ts) / interval - 1);\n"
            // Reset the outer factor; the old code declared a new shadowing local instead.
            + "    if factor < 0 then\n"
            + "        factor = 0;\n"
            + "    end;\n"
            + "    local add_count = factor * rate;\n"
            + "    if add_count > 0 then\n"
            + "        count = count + add_count;\n"
            + "        if capacity > 0 and count > capacity then\n"
            + "            count = capacity;\n"
            + "        end;\n"
            + "        redis.call('hset', r_key, 'count', count);\n"
            + "        redis.call('hset', r_key, 'gen_ts', ts);\n"
            + "        redis.call('pexpireat', r_key, expire);\n"
            + "    end;\n"
            + "    if (count - num) > -1 then\n"
            + "        redis.call('hset', r_key, 'count', count - num);\n"
            + "        redis.call('hset', r_key, 'get_ts', ts);\n"
            + "        return (count - num);\n"
            + "    else\n"
            + "        return -1;\n"
            + "    end;\n"
            + "else\n"
            + "    redis.call('hset', r_key, 'gen_ts', ts);\n"
            + "    redis.call('hset', r_key, 'count', capacity - num);\n"
            + "    redis.call('hset', r_key, 'get_ts', ts);\n"
            + "    redis.call('pexpireat', r_key, expire);\n"
            + "    return (capacity - num);\n"
            + "end;";

    /**
     * Returns tokens to a token bucket. ARGV[1]: num, ARGV[2]: capacity (values of 0 or less mean no cap).
     * Returns the new token count, or -1 when the bucket no longer exists.
     * EXISTS yields an integer in Lua, so it is compared to the number 1 (not the string '1').
     */
    private static final String ADD_TOKEN_SCRIP = "local num = tonumber(ARGV[1]);\n"
            + "local capacity = tonumber(ARGV[2]);\n"
            + "if redis.call('exists', KEYS[1]) == 1 then\n"
            + "    local count = (tonumber(redis.call('hget', KEYS[1], 'count')) or 0) + num;\n"
            + "    if capacity > 0 and count > capacity then\n"
            + "        count = capacity;\n"
            + "    end;\n"
            + "    redis.call('hset', KEYS[1], 'count', count);\n"
            + "    return count;\n"
            + "else\n"
            + "    return -1;\n"
            + "end;";

    /**
     * Leaky-bucket consume script. Same ARGV layout as the token script.
     * Returns the remaining room (capacity - level) after adding num, or -1 when the bucket would overflow.
     */
    private static final String CONSUME_LEAKY_SCRIP = "local r_key = tostring(KEYS[1]);\n"
            + "local num = tonumber(ARGV[1]);\n"
            + "local interval = tonumber(ARGV[2]);\n"
            + "local rate = tonumber(ARGV[3]);\n"
            + "local capacity = tonumber(ARGV[4]);\n"
            + "local compareTime = tonumber(ARGV[5]);\n"
            + "local ts = tonumber(ARGV[6]);\n"
            + "local expire = tonumber(ARGV[7]);\n"
            + "if tonumber(redis.call('exists', r_key)) == 1 then\n"
            + "    local gen_ts = tonumber(redis.call('hget', r_key, 'gen_ts'));\n"
            + "    local count = tonumber(redis.call('hget', r_key, 'count'));\n"
            + "    local factor = math.ceil((compareTime - gen_ts) / interval - 1);\n"
            + "    if factor < 0 then\n"
            + "        factor = 0;\n"
            + "    end;\n"
            + "    local de_count = factor * rate;\n"
            + "    if de_count > 0 then\n"
            + "        count = count - de_count;\n"
            + "        if count < 0 then\n"
            + "            count = 0;\n"
            + "        end;\n"
            + "        redis.call('hset', r_key, 'count', count);\n"
            + "        redis.call('hset', r_key, 'gen_ts', ts);\n"
            + "        redis.call('pexpireat', r_key, expire);\n"
            + "    end;\n"
            + "    local total = count + num;\n"
            + "    if capacity - total > -1 then\n"
            + "        redis.call('hset', r_key, 'count', total);\n"
            + "        redis.call('hset', r_key, 'get_ts', ts);\n"
            + "        return (capacity - total);\n"
            + "    else\n"
            + "        return -1;\n"
            + "    end;\n"
            + "else\n"
            + "    redis.call('hset', r_key, 'gen_ts', ts);\n"
            + "    redis.call('hset', r_key, 'count', num);\n"
            + "    redis.call('hset', r_key, 'get_ts', ts);\n"
            + "    redis.call('pexpireat', r_key, expire);\n"
            + "    return (capacity - num);\n"
            + "end;";

    /**
     * Removes units from a leaky bucket (never below 0). ARGV[1]: num, ARGV[2]: capacity.
     * Returns capacity minus the new level, or -1 when the bucket no longer exists.
     */
    private static final String DEDUCE_LEAKY_SCRIP = "local num = tonumber(ARGV[1]);\n"
            + "local capacity = tonumber(ARGV[2]);\n"
            + "if redis.call('exists', KEYS[1]) == 1 then\n"
            + "    local count = (tonumber(redis.call('hget', KEYS[1], 'count')) or 0) - num;\n"
            + "    if count < 0 then\n"
            + "        count = 0;\n"
            + "    end;\n"
            + "    redis.call('hset', KEYS[1], 'count', count);\n"
            + "    return capacity - count;\n"
            + "else\n"
            + "    return -1;\n"
            + "end;";

    /** Scripts are built once so their SHA1 is computed a single time and reused. */
    private static final DefaultRedisScript<Object> CONSUME_TOKEN_SCRIPT = newScript(CONSUME_TOKEN_SCRIP);
    private static final DefaultRedisScript<Object> CONSUME_LEAKY_SCRIPT = newScript(CONSUME_LEAKY_SCRIP);
    private static final DefaultRedisScript<Object> ADD_TOKEN_SCRIPT = newScript(ADD_TOKEN_SCRIP);
    private static final DefaultRedisScript<Object> DEDUCE_LEAKY_SCRIPT = newScript(DEDUCE_LEAKY_SCRIP);

    /** Serializes the numeric script arguments as decimal strings. */
    private final LongStringSerialize argsSerializer = new LongStringSerialize();
    /** Deserializes the script reply. */
    private final GenericFastJsonRedisSerializer resultSerializer = new GenericFastJsonRedisSerializer();

    private static DefaultRedisScript<Object> newScript(String text) {
        DefaultRedisScript<Object> script = new DefaultRedisScript<>();
        script.setScriptText(text);
        script.setResultType(Object.class);
        return script;
    }

    /**
     * Consumes {@code num} units from the bucket stored at {@code key}.
     *
     * @param key         bucket key
     * @param num         number of units to consume
     * @param ts          time of this consumption, epoch millis; values below 1 (or null) mean "now"
     * @param expire      absolute time at which an idle bucket is reset, epoch millis
     * @param rate        tokens generated (token bucket) or units drained (leaky bucket) per interval
     * @param interval    interval between refills/drains, millis
     * @param capacity    maximum bucket size
     * @param type        algorithm: 0 = token bucket, anything else = leaky bucket
     * @param compareTime reference time for elapsed-interval calculation; values below 1 (or null) mean "now"
     * @return tokens left (token bucket) or remaining room (leaky bucket); -1 when rejected
     */
    @Override
    public Long consumeCount(String key, Long num, Long ts, Long expire, Long rate, Long interval, Long capacity, int type, Long compareTime) {
        long millis = System.currentTimeMillis();
        if (compareTime == null || compareTime < 1) {
            compareTime = millis;
        }
        if (ts == null || ts < 1) {
            ts = millis;
        }
        DefaultRedisScript<Object> script = type == 0 ? CONSUME_TOKEN_SCRIPT : CONSUME_LEAKY_SCRIPT;
        Object result = redisTemplate.execute(script, argsSerializer, resultSerializer, Lists.newArrayList(key),
                num, interval, rate, capacity, compareTime, ts, expire);
        return toLong(result);
    }

    /**
     * Reads the current count without changing it.
     *
     * @param key bucket key
     * @return the stored count, or 0 when the bucket or field does not exist
     */
    @Override
    public Long getCurrentCount(String key) {
        Object count = redisTemplate.opsForHash().get(key, "count");
        if (count != null) {
            return Long.valueOf(count.toString());
        }
        return 0L;
    }

    /**
     * Returns units to a bucket, e.g. to roll back a consumption.
     *
     * @param key      bucket key
     * @param num      tokens to add (token bucket) or units to remove (leaky bucket)
     * @param capacity token bucket: values below 1 mean no cap, otherwise the count is capped at it;
     *                 leaky bucket: only used to compute the returned remaining room
     * @param type     algorithm: 0 = token bucket, anything else = leaky bucket
     * @return new token count (token bucket) or remaining room (leaky bucket); -1 if the key has expired
     */
    @Override
    public Long addCount(String key, long num, long capacity, int type) {
        DefaultRedisScript<Object> script = type == 0 ? ADD_TOKEN_SCRIPT : DEDUCE_LEAKY_SCRIPT;
        Object result = redisTemplate.execute(script, argsSerializer, resultSerializer, Lists.newArrayList(key),
                num, capacity);
        return toLong(result);
    }

    /**
     * Extracts the numeric script reply. Depending on the driver and serializer the reply may
     * arrive as a single value or wrapped in a list, so both shapes are accepted.
     */
    private static Long toLong(Object result) {
        Object value = result;
        if (value instanceof List) {
            List<?> values = (List<?>) value;
            value = values.isEmpty() ? null : values.get(0);
        }
        if (value == null) {
            throw new IllegalStateException("rate limit script returned no value");
        }
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        return Long.valueOf(value.toString());
    }

}

6.编写http获取request工具类

package com.around.common.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.util.*;

/**
 * HTTP helpers: access to the current servlet request/response, request body and headers,
 * and client/host IP addresses.
 *
 * @program:
 * @description: HTTP utilities
 * @author: moodincode
 * @create: 2020/9/28
 **/
public class WebRequestUtil {
    private final static Logger log = LoggerFactory.getLogger(WebRequestUtil.class);
    /** Lazily cached result of {@link #getLocalIpAddress()}; null until a lookup succeeds. */
    private static String localIp;

    private static final String UNKNOWN = "unknown";

    private static final String SYMBOL = ",";

    /** Headers consulted (in order) after {@code x-forwarded-for} when looking for the client IP. */
    private static final String[] FALLBACK_IP_HEADERS = {
            "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR"};

    /**
     * Returns the request bound to the current thread through Spring's {@link RequestContextHolder}.
     *
     * @return the current request, or null when no servlet request is bound
     */
    public static HttpServletRequest getRequest() {
        ServletRequestAttributes attributes = currentServletAttributes();
        return attributes == null ? null : attributes.getRequest();
    }

    /**
     * Returns the response bound to the current thread through Spring's {@link RequestContextHolder}.
     *
     * @return the current response, or null when none is bound
     */
    public static HttpServletResponse getResponse() {
        ServletRequestAttributes attributes = currentServletAttributes();
        return attributes == null ? null : attributes.getResponse();
    }

    /** Thread-bound attributes if they are servlet-based; null otherwise (instead of a ClassCastException). */
    private static ServletRequestAttributes currentServletAttributes() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        return attributes instanceof ServletRequestAttributes ? (ServletRequestAttributes) attributes : null;
    }

    /**
     * Reads the whole request body and strips CR, LF and TAB characters.
     * The request reader can only be consumed once, and it is closed here.
     *
     * @param request servlet request
     * @return the body text, or whatever was read before an I/O error (the error is logged)
     */
    public static String readRequestBodyParams(HttpServletRequest request) {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader br = request.getReader()) {
            String line;
            while ((line = br.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            log.error("获取body参数失败", e);
        }
        return sb.toString().replaceAll("\r|\n|\t", "");
    }

    /**
     * Client IP of the request bound to the current thread.
     *
     * @return client IP, or this host's IP when no request is bound
     */
    public static String getIpAddress() {
        return getIpAddress(getRequest());
    }

    /**
     * Determines the client IP of {@code request}.
     *
     * <p>SECURITY NOTE(review): the forwarding headers below are client-controlled unless a
     * trusted reverse proxy overwrites them. Anything keyed on this value (e.g. per-IP rate
     * limiting) can be bypassed by spoofing {@code X-Forwarded-For}. Consider trusting these
     * headers only when the request comes from a known proxy.
     *
     * @param request servlet request; may be null
     * @return the first address of a multi-hop forwarding chain, the remote address, or this
     *         host's IP when {@code request} is null
     */
    public static String getIpAddress(HttpServletRequest request) {
        // No request (e.g. a background job): use the local machine's address.
        if (request == null) {
            return getLocalIpAddress();
        }
        String ip = request.getHeader("x-forwarded-for");
        for (String header : FALLBACK_IP_HEADERS) {
            if (isUnknown(ip)) {
                ip = request.getHeader(header);
            }
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        // Behind several proxies the header holds "client, proxy1, proxy2": keep the first entry.
        if (ip != null && ip.contains(SYMBOL)) {
            ip = ip.substring(0, ip.indexOf(SYMBOL)).trim();
        }
        return ip;
    }

    /** True when a header value carries no usable address. */
    private static boolean isUnknown(String ip) {
        return ip == null || ip.isEmpty() || UNKNOWN.equalsIgnoreCase(ip);
    }

    /**
     * Looks up a cookie value by name.
     *
     * @param request servlet request; may be null
     * @param key     cookie name
     * @return the value of the first matching cookie, or null if there is none
     */
    public static String getCookie(HttpServletRequest request, String key) {
        if (request == null || key == null) {
            return null;
        }
        Cookie[] cookies = request.getCookies();
        if (cookies != null) {
            for (Cookie cookie : cookies) {
                if (key.equals(cookie.getName())) {
                    return cookie.getValue();
                }
            }
        }
        return null;
    }

    /**
     * Lists all request headers as "name:value" strings.
     *
     * @param request servlet request
     * @return header lines; empty if the container does not expose header names
     */
    public static List<String> getHeaders(HttpServletRequest request) {
        List<String> headList = new ArrayList<>();
        Enumeration<String> headers = request.getHeaderNames();
        while (headers != null && headers.hasMoreElements()) {
            String headName = headers.nextElement();
            headList.add(String.format("%s:%s", headName, request.getHeader(headName)));
        }
        return headList;
    }

    /**
     * Collects request headers into a map (the first value of each header).
     *
     * @param request servlet request
     * @return header name to value; empty if the container does not expose header names
     */
    public static Map<String, String> getHeaderMap(HttpServletRequest request) {
        Map<String, String> headMap = new HashMap<>();
        Enumeration<String> headers = request.getHeaderNames();
        while (headers != null && headers.hasMoreElements()) {
            String headName = headers.nextElement();
            headMap.put(headName, request.getHeader(headName));
        }
        return headMap;
    }

    /**
     * Returns this machine's IP address, caching the first successful result.
     * Preference order: a site-local address, then any non-loopback address, then
     * {@link InetAddress#getLocalHost()}.
     *
     * @return the local IP, or null if it could not be determined (the failure is logged)
     */
    public static String getLocalIpAddress() {
        if (localIp != null) {
            return localIp;
        }
        try {
            InetAddress candidateAddress = null;
            Enumeration<NetworkInterface> ifaces = NetworkInterface.getNetworkInterfaces();
            while (ifaces != null && ifaces.hasMoreElements()) {
                NetworkInterface iface = ifaces.nextElement();
                Enumeration<InetAddress> inetAddrs = iface.getInetAddresses();
                while (inetAddrs.hasMoreElements()) {
                    InetAddress inetAddr = inetAddrs.nextElement();
                    if (inetAddr.isLoopbackAddress()) {
                        continue;
                    }
                    if (inetAddr.isSiteLocalAddress()) {
                        // A site-local (private network) address is the preferred answer.
                        localIp = inetAddr.getHostAddress();
                        return localIp;
                    }
                    if (candidateAddress == null) {
                        candidateAddress = inetAddr;
                    }
                }
            }
            if (candidateAddress != null) {
                localIp = candidateAddress.getHostAddress();
                return localIp;
            }
            // No non-loopback address found: fall back to the JDK's notion of the local host.
            localIp = InetAddress.getLocalHost().getHostAddress();
        } catch (Exception e) {
            log.error("获取本机IP地址失败", e);
        }
        return localIp;
    }

    /**
     * Reads the whole request body. The reader is not closed here. Because
     * {@code request.getReader()} can be consumed only once, the request must be wrapped
     * (cached) if the body is needed again later.
     *
     * @param request servlet request
     * @return the body without line terminators, or null on I/O error (the error is logged)
     */
    public static String getBodyString(HttpServletRequest request){
        try {
            BufferedReader br = request.getReader();
            StringBuilder whole = new StringBuilder();
            String line;
            while ((line = br.readLine()) != null) {
                whole.append(line);
            }
            return whole.toString();
        } catch (IOException e) {
            log.error("解析body参数失败,原因:", e);
        }
        return null;
    }
}

7.编写限流单位计算工具

package com.around.common.utils.limitrate;

/**
 * Timing parameters for a single rate-limit check, passed on to the provider.
 *
 * @description: rate-limit time calculation
 * @author: moodincode
 * @create 2021/1/08
 */
public class LimitTime {
    /** Time of the request, epoch millis (sent to the provider as {@code ts}). */
    private Long millis;
    /** Absolute bucket expiry, epoch millis. */
    private Long expire;
    /** Refill/drain interval in milliseconds. */
    private Long intermission;
    /** Reference time for the elapsed-interval calculation, epoch millis. */
    private Long compareTime;

    public LimitTime(Long millis, Long expire, Long intermission, Long compareTime) {
        this.millis = millis;
        this.expire = expire;
        this.intermission = intermission;
        this.compareTime = compareTime;
    }

    /**
     * Builds the timing for a check made at {@code millis}.
     *
     * <p>The expiry is the end of the {@code unit} window starting at {@code millis}. An
     * {@code interval} below 1000 ms is replaced by the full window length, so such rules
     * refill or drain once per window.
     *
     * @param millis   request time, epoch millis (also used as the comparison time)
     * @param interval requested refill/drain interval in millis (must not be null)
     * @param unit     window unit
     * @return the computed timing
     */
    public static LimitTime calculateTime(long millis, Long interval, LimitUnitEnum unit) {
        Long expireTs = unit.getExpireTs(millis);
        Long effectiveInterval = interval < 1000L ? expireTs - millis : interval;
        return new LimitTime(millis, expireTs, effectiveInterval, millis);
    }

    public Long getMillis() {
        return millis;
    }

    public void setMillis(Long millis) {
        this.millis = millis;
    }

    public Long getExpire() {
        return expire;
    }

    public void setExpire(Long expire) {
        this.expire = expire;
    }

    public Long getIntermission() {
        return intermission;
    }

    public void setIntermission(Long intermission) {
        this.intermission = intermission;
    }

    public Long getCompareTime() {
        return compareTime;
    }

    public void setCompareTime(Long compareTime) {
        this.compareTime = compareTime;
    }
}

8.编写限流工具类

package com.around.common.utils.limitrate;

import com.around.common.utils.WebRequestUtil;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;

/**
 * Builds Redis keys for rate-limit rules and delegates bucket operations to the
 * {@link LimitRateProvider}.
 *
 * @description: rate-limit utility
 * @author: moodincode
 * @create 2021/1/8
 */
@Component
public class LimitRateUtil {
    /**
     * Key prefix for per-application-instance buckets (APP_* types). It mixes the host IP with
     * the JVM start time, so each started instance gets its own buckets.
     */
    public static final String MACHINE_ID = buildMachineId();
    /** Key prefix shared by all instances (GLOBAL_* types). */
    public static final String GLOBAL = "limit:00:";

    @Resource
    private LimitRateProvider limitRateProvider;

    /** Computes {@link #MACHINE_ID} once, at class initialization. */
    private static String buildMachineId() {
        StringBuilder raw = new StringBuilder();
        String localIpAddress = WebRequestUtil.getLocalIpAddress();
        if (!StringUtils.isEmpty(localIpAddress)) {
            raw.append(localIpAddress.replace(".", "_").replace(":", "_"));
        }
        raw.append('_').append(System.currentTimeMillis());
        return "limit:m" + Math.abs(raw.toString().hashCode()) + ":";
    }

    /**
     * Consumes one unit for {@code key} under the scope implied by {@code type}.
     *
     * @param request  current request (used for IP-scoped types); may be null
     * @param key      rule-specific key part
     * @param rate     tokens generated / units drained per interval
     * @param interval interval in millis; values below 1000 mean "one whole unit window"
     * @param unit     window unit
     * @param type     scope and algorithm
     * @param capacity bucket capacity; values below 1 fall back to {@code rate}
     * @return provider result; negative means rejected
     */
    public Long consumeCount(HttpServletRequest request,String key, Long rate, Long interval, LimitUnitEnum unit, LimitTypeEnum type, Long capacity){
        String keyPrefix = getKeySuffix(type, request);
        long millis = System.currentTimeMillis();
        LimitTime time = LimitTime.calculateTime(millis, interval, unit);
        return consumeCount(keyPrefix + key, 1L, rate, type, capacity, time);
    }

    /**
     * Consumes {@code num} units from the bucket at the fully built {@code key}.
     *
     * <p>NOTE(review): a capacity below 1 (including -1, which the annotation describes as
     * "unlimited") is replaced by {@code rate} here.
     *
     * @param key      full Redis key (prefix included)
     * @param num      units to consume
     * @param rate     tokens generated / units drained per interval
     * @param type     scope and algorithm
     * @param capacity bucket capacity; null or values below 1 fall back to {@code rate}
     * @param time     timing parameters
     * @return provider result; negative means rejected
     */
    public Long consumeCount(String key,Long num, Long rate, LimitTypeEnum type, Long capacity, LimitTime time) {
        Long effectiveCapacity = (capacity == null || capacity < 1) ? rate : capacity;
        return limitRateProvider.consumeCount(key, num, time.getMillis(), time.getExpire(), rate,
                time.getIntermission(), effectiveCapacity, type.getCode(), time.getCompareTime());
    }

    /**
     * Returns the key prefix for {@code type}. Despite the method name, the result is placed
     * in front of the rule key.
     *
     * @param type    scope and algorithm
     * @param request current request (used for IP-scoped types); may be null
     * @return {@link #GLOBAL}, {@link #MACHINE_ID}, or an IP-based prefix
     */
    public String getKeySuffix(LimitTypeEnum type, HttpServletRequest request) {
        if (LimitTypeEnum.GLOBAL_LEAKY.equals(type) || LimitTypeEnum.GLOBAL_TOKEN.equals(type)) {
            return GLOBAL;
        }
        if (LimitTypeEnum.APP_LEAKY.equals(type) || LimitTypeEnum.APP_TOKEN.equals(type)) {
            return MACHINE_ID;
        }
        String ipAddress = WebRequestUtil.getIpAddress(request);
        // Avoid an NPE when no address can be determined; such callers share one bucket.
        if (StringUtils.isEmpty(ipAddress)) {
            ipAddress = "unknown";
        }
        return "limit:i" + ipAddress.replace(".", "_").replace(":", "_") + ":";
    }

    /**
     * Returns units to a bucket (used for rollback).
     *
     * @param key      full Redis key
     * @param num      units to return
     * @param capacity capacity passed to the provider
     * @param type     scope and algorithm
     * @return provider result; -1 when the key has expired
     */
    public Long addCount(String key, Long num, Long capacity, LimitTypeEnum type){
        return limitRateProvider.addCount(key, num, capacity, type.getCode());
    }

    /** Prints the computed machine prefix (debug helper). */
    public static void main(String[] args) {
        System.out.println(MACHINE_ID);
    }

}

9.编写注解类

package com.around.common.utils.limitrate;

import java.lang.annotation.*;

/**
 * Declares a rate-limit rule on a method. The annotation is repeatable, so several rules
 * (e.g. per-second plus per-day) can be combined on one method.
 *
 * @description: rate-limit rule annotation
 * @author: moodincode
 * @create 2021/1/8
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(LimitRates.class)
public @interface LimitRate {

    /**
     * Scope (client IP, application instance, global) and algorithm (token or leaky bucket).
     * @return the limiting type
     */
    LimitTypeEnum type() default LimitTypeEnum.IP_TOKEN;

    /**
     * Window unit, which also determines when an idle bucket expires.
     * @return the window unit
     */
    LimitUnitEnum unit() default LimitUnitEnum.PER_SECOND;

    /**
     * Tokens generated (token bucket) or units drained (leaky bucket) per interval. Default 200.
     * @return the rate
     */
    long rate() default 200;

    /**
     * Refill/drain interval in milliseconds. Values below 1000 mean one whole {@link #unit()} window.
     * @return the interval
     */
    long interval() default 1;

    /**
     * Maximum bucket size. 0 (the default) uses {@link #rate()} as the capacity. The original
     * documentation describes -1 as "unlimited", but the limiter substitutes the rate for any
     * value below 1.
     * @return the capacity
     */
    long capacity() default 0;

    /**
     * Redis key prefix. When empty, the hash of "className:methodName" is used.
     * @return the key prefix
     */
    String name() default "";

    /**
     * Optional SpEL expression evaluated against the method parameters (e.g. {@code #userId});
     * the result is appended to the key.
     * @return the key expression
     */
    String key() default "";

    /**
     * Message of the exception thrown when the call is rejected.
     * @return the rejection message
     */
    String msg() default "接口被限制访问,请稍后再试";

    /**
     * Priority among several rules on one method; higher values are checked first.
     * @return the order
     */
    int order() default 0;

    /**
     * Whether the consumed unit is given back when the invocation fails with an exception.
     * @return true to roll back on exception
     */
    boolean rbEx() default false;
}
package com.around.common.utils.limitrate;

import java.lang.annotation.*;

/**
 * Container that lets {@link LimitRate} be repeated on a single method. The compiler wraps
 * repeated {@code @LimitRate} annotations in this type automatically.
 *
 * @description: allows multiple rules
 * @author: moodincode
 * @create 2021/1/12
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface LimitRates {
    /**
     * The repeated rules.
     * @return the contained rules
     */
    LimitRate[] value();
}

10.编写AOP拦截器

package com.around.common.utils.limitrate;
import com.around.common.utils.WebRequestUtil;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.Order;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.text.MessageFormat;
import java.util.*;

/**
 * Aspect that enforces {@link LimitRate} rules before the annotated method runs.
 *
 * <p>{@code @Order(999)} gives this aspect a relatively low priority to avoid ordering
 * problems with other aspects.
 *
 * @description: rate-limit aspect
 * @author: moodincode
 * @create: 2021/01/12
 **/
@Aspect
@Component
@Order(999)
public class LimitRateCheckAop {
    private static final Logger log = LoggerFactory.getLogger(LimitRateCheckAop.class);
    /** SpEL parsers are thread-safe, so one instance is shared. */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    @Resource
    private LimitRateUtil limitRateUtil;

    /**
     * Checks every rule on the intercepted method (highest {@code order} first), then proceeds.
     *
     * <p>Repeated {@code @LimitRate} annotations are compiled into a {@code @LimitRates}
     * container, so the pointcut must match both forms. Otherwise combined rules are never
     * enforced.
     *
     * <p>Units consumed by rules with {@code rbEx = true} are given back if a later rule rejects
     * the call or the method throws.
     *
     * @param point intercepted invocation
     * @return the method's result
     * @throws Throwable a RuntimeException carrying {@link LimitRate#msg()} when a rule rejects,
     *                   or whatever the method throws
     */
    @Around("@annotation(com.around.common.utils.limitrate.LimitRate)"
            + " || @annotation(com.around.common.utils.limitrate.LimitRates)")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        log.trace("enter LimitRateCheckAop");
        MethodSignature methodSignature = (MethodSignature) point.getSignature();
        // getAnnotationsByType unwraps the @LimitRates container, so single and repeated rules look alike.
        LimitRate[] limitRates = methodSignature.getMethod().getAnnotationsByType(LimitRate.class);
        Arrays.sort(limitRates, Comparator.comparingInt(LimitRate::order).reversed());
        HttpServletRequest request = WebRequestUtil.getRequest();
        boolean multipleRules = limitRates.length > 1;
        // A list (not a map) so that two consumptions of the same key are both rolled back.
        List<Map.Entry<String, LimitRate>> rollbackKeys = new ArrayList<>();
        try {
            for (LimitRate limitRate : limitRates) {
                String key = buildFullKey(point, methodSignature, limitRate, request, multipleRules);
                LimitTime time = LimitTime.calculateTime(System.currentTimeMillis(), limitRate.interval(), limitRate.unit());
                Long count = limitRateUtil.consumeCount(key, 1L, limitRate.rate(), limitRate.type(), limitRate.capacity(), time);
                if (count < 0) {
                    throw new RuntimeException(limitRate.msg());
                }
                if (limitRate.rbEx()) {
                    rollbackKeys.add(new AbstractMap.SimpleImmutableEntry<>(key, limitRate));
                }
            }
            return point.proceed();
        } catch (Exception e) {
            // The old check was inverted (rolled back only when nothing was recorded).
            doRollback(rollbackKeys);
            throw e;
        }
    }

    /**
     * Full Redis key for one rule: scope prefix + rule key. When several rules guard one
     * method, the type and unit are appended; otherwise rules sharing the default name/key
     * would write to the same bucket.
     */
    private String buildFullKey(ProceedingJoinPoint point, MethodSignature methodSignature, LimitRate limitRate,
                                HttpServletRequest request, boolean multipleRules) {
        String key = limitRateUtil.getKeySuffix(limitRate.type(), request)
                + buildRedisKey(point, methodSignature, limitRate);
        if (multipleRules) {
            key = key + ":" + limitRate.type().getName() + ":" + limitRate.unit().getCode();
        }
        return key;
    }

    /**
     * Gives back one unit for every recorded key. A failure is logged, so it never hides the
     * exception that triggered the rollback.
     *
     * @param rollbackKeys consumed keys paired with their rules
     */
    private void doRollback(List<Map.Entry<String, LimitRate>> rollbackKeys) {
        for (Map.Entry<String, LimitRate> entry : rollbackKeys) {
            LimitRate limitRate = entry.getValue();
            // Mirror LimitRateUtil.consumeCount: a capacity below 1 means the rate is the bucket size.
            long capacity = limitRate.capacity() < 1 ? limitRate.rate() : limitRate.capacity();
            try {
                limitRateUtil.addCount(entry.getKey(), 1L, capacity, limitRate.type());
            } catch (Exception ex) {
                log.error("限流回滚失败,key={}", entry.getKey(), ex);
            }
        }
    }

    /**
     * Rule key without the scope prefix: {@link LimitRate#name()} (or the hash of
     * "className:methodName") followed by the evaluated SpEL {@link LimitRate#key()}, if any.
     *
     * @param point           intercepted invocation (supplies the arguments)
     * @param methodSignature signature of the intercepted method
     * @param limitRate       rule
     * @return the rule key
     */
    private String buildRedisKey(ProceedingJoinPoint point, MethodSignature methodSignature,LimitRate limitRate) {
        String methodName = methodSignature.getName();
        String className = methodSignature.getDeclaringTypeName();
        log.debug("类名{}方法名{}", className, methodName);
        String keyPrefix;
        if (StringUtils.isEmpty(limitRate.name())) {
            // Use only the hash so that long class/method names do not bloat the key.
            keyPrefix = String.valueOf(Math.abs(MessageFormat.format("{0}:{1}", className, methodName).hashCode()));
        } else {
            keyPrefix = limitRate.name();
        }

        String paramKey = "";
        if (!StringUtils.isEmpty(limitRate.key())) {
            paramKey = buildExpressKey(methodSignature, limitRate.key(), point.getArgs());
        }
        return keyPrefix + paramKey;
    }

    /**
     * Evaluates a SpEL expression with each method parameter bound as {@code #paramName}.
     * A malformed expression fails fast. An evaluation error or null result is logged and
     * yields an empty string.
     *
     * @param methodSignature signature providing the parameter names
     * @param el              SpEL expression from the annotation
     * @param args            invocation arguments
     * @return the evaluated key part, or "" on failure
     */
    private String buildExpressKey(MethodSignature methodSignature, String el, Object[] args) {
        String paramKey = "";
        Expression expression = EXPRESSION_PARSER.parseExpression(el);
        EvaluationContext context = new StandardEvaluationContext();
        // Parameter names can be unavailable when compiled without debug info / -parameters.
        String[] parameterNames = methodSignature.getParameterNames();
        if (parameterNames != null && args != null) {
            for (int i = 0; i < parameterNames.length && i < args.length; i++) {
                context.setVariable(parameterNames[i], args[i]);
            }
        }
        try {
            paramKey = expression.getValue(context).toString();
            log.debug("限流工具构建key,el表达式解析key成功,el={},解析后值为:{}", el, paramKey);
        } catch (Exception e) {
            // Pass the exception as the last argument so SLF4J logs the stack trace.
            log.error("参数key={}解析失败", el, e);
        }
        return paramKey;
    }
}

11.在需要限流的controller方法上使用@LimitRate注解(该注解的@Target为METHOD,只能标注在方法上,可重复标注以组合多条规则),具体配置项见注解定义

项目git地址 https://gitee.com/moodincode/com_around_project/tree/master/common-utils

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

心情加密语

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值