pom文件引用
<!-- Dependencies for the Redis-backed @RateLimiter.
     NOTE: the Java classes also use spring-boot-starter-web (servlet API, RequestContextHolder)
     and Lombok (@Slf4j); these are assumed to already be present in the project. -->
<!-- AOP support: required for the @Aspect that enforces @RateLimiter -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Redis client and RedisTemplate, used to run lua/limit.lua -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Hutool: base classes for ConvertUtils / IpUtils and StrFormatter -->
<dependency>
    <groupId>cn.hutool</groupId>
    <artifactId>hutool-all</artifactId>
    <version>5.8.5</version>
</dependency>
<!-- Needed by Lettuce connection pooling (spring.redis.lettuce.pool.*).
     No explicit version: Spring Boot's dependency management selects one it is tested with.
     The previously pinned 2.0 is far older and may lack methods Spring Boot calls. -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-pool2</artifactId>
</dependency>
<!-- org.example.rate.StringUtils extends org.apache.commons.lang3.StringUtils;
     version managed by Spring Boot. -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-lang3</artifactId>
</dependency>
lua脚本
-- Fixed-window request counter used by the @RateLimiter aspect.
-- KEYS[1]: counter key
-- ARGV[1]: maximum requests allowed per window (count)
-- ARGV[2]: window length in seconds (time)
-- Returns the current counter value; the caller rejects the request when it
-- exceeds count. Once the counter is above count it is no longer incremented.
local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])

-- ARGV must arrive as plain numeric strings. With a binary value serializer
-- (e.g. JDK serialization) tonumber() yields nil, so fail with a clear message
-- instead of an opaque "compare number with nil" error.
if not count or not time then
  return redis.error_reply('limit.lua: ARGV[1] (count) and ARGV[2] (time) must be numeric')
end

-- GET returns false for a missing key inside scripts; treat that as 0.
local current = tonumber(redis.call('get', key) or '0')
if current <= count then
  current = redis.call('incr', key)
end

-- TTL == -1 means the key exists without an expiry: this starts the window on
-- the first hit, and also repairs a key that lost its TTL so a client can
-- never be blocked permanently.
if redis.call('ttl', key) == -1 then
  redis.call('expire', key, time)
end

return current
配置文件
spring:
  # Redis configuration
  # NOTE: Spring Boot 3.x reads these settings from spring.data.redis.* instead.
  redis:
    # Host; override with the REDIS_HOST environment variable
    host: ${REDIS_HOST:192.168.1.7}
    # Password; avoid committing real credentials, supply REDIS_PASSWORD instead
    password: ${REDIS_PASSWORD:123456}
    # Port (Redis default is 6379)
    port: 6379
    # Database index
    database: 0
    # Connection timeout
    timeout: 10s
    lettuce:
      pool:
        # Minimum number of idle connections kept in the pool
        min-idle: 0
        # Maximum number of idle connections in the pool
        max-idle: 8
        # Maximum number of connections the pool hands out at once
        max-active: 8
        # Maximum time to block waiting for a connection (negative = no limit)
        max-wait: -1ms
package org.example.rate;
import java.lang.annotation.*;
/**
 * Marks a method whose invocations are throttled by a Redis counter.
 * <p>
 * The attributes below are read by {@code RateLimiterAspect}, which rejects a call once
 * the counter for the method exceeds {@link #count()} within {@link #time()} seconds.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /** Default prefix for every rate-limit counter key stored in Redis. */
    public static final String RATE_LIMIT_KEY = "rate_limit:";

    /** Prefix of this method's counter key; defaults to {@link #RATE_LIMIT_KEY}. */
    String key() default RATE_LIMIT_KEY;

    /** Window length, in seconds. */
    int time() default 60;

    /** Maximum number of calls allowed within one window. */
    int count() default 100;

    /** Whether the limit is shared by all callers or tracked per client IP. */
    LimitType limitType() default LimitType.DEFAULT;
}
package org.example.rate;
import cn.hutool.core.convert.Convert;
/**
 * Thin project-local subclass of Hutool's {@link Convert}.
 * All conversion helpers are inherited static methods; nothing is added here.
 */
public class ConvertUtils extends Convert {

    /** Equivalent to the implicit default constructor; declared only for explicitness. */
    public ConvertUtils() {
        super();
    }
}
package org.example.rate;
import org.springframework.http.HttpHeaders;
/**
 * Spring's {@link HttpHeaders} extended with the proxy-related header names (and the
 * "unknown" placeholder value) that {@code IpUtils} consults when resolving a client IP.
 */
public class HttpHeadersUtils extends HttpHeaders {

    /** Header value treated as "no address available". */
    public static final String UNKNOWN = "unknown";

    /** Forwarded-client chain header, lower-case spelling. */
    public static final String LOWER_X_FORWARDED_FOR = "x-forwarded-for";

    /** Forwarded-client chain header, canonical spelling. */
    public static final String UPPER_X_FORWARDED_FOR = "X-Forwarded-For";

    /** Client address header set by some proxies. */
    public static final String PROXY_CLIENT_IP = "Proxy-Client-IP";

    /** Client address header set by the WebLogic proxy plug-in. */
    public static final String WL_PROXY_CLIENT_IP = "WL-Proxy-Client-IP";

    /** Client address header commonly set by nginx-style reverse proxies. */
    public static final String X_REAL_IP = "X-Real-IP";
}
package org.example.rate;
import cn.hutool.core.net.NetUtil;
import javax.servlet.http.HttpServletRequest;
import java.util.Objects;
/**
 * Helpers for resolving and classifying client IP addresses.
 * Extends Hutool's {@link NetUtil}, from which {@code LOCAL_IP} and
 * {@code getMultistageReverseProxyIp} are inherited.
 */
public class IpUtils extends NetUtil {

    /** IPv6 loopback address in its full textual form. */
    public static final String IPV6_LOCAL_IP = "0:0:0:0:0:0:0:1";

    /** IPv6 loopback address in its compressed textual form. */
    private static final String IPV6_LOCAL_IP_COMPRESSED = "::1";

    /**
     * Headers consulted, in this order, before falling back to the socket peer address.
     * NOTE(review): these headers are client-controlled unless a trusted reverse proxy
     * overwrites them, so a caller can spoof them (e.g. to evade IP-based rate limits).
     */
    private static final String[] CLIENT_IP_HEADERS = {
            HttpHeadersUtils.LOWER_X_FORWARDED_FOR,
            HttpHeadersUtils.PROXY_CLIENT_IP,
            HttpHeadersUtils.UPPER_X_FORWARDED_FOR,
            HttpHeadersUtils.WL_PROXY_CLIENT_IP,
            HttpHeadersUtils.X_REAL_IP
    };

    /**
     * Resolves the client IP address of a request.
     *
     * @param request the request; may be {@code null}
     * @return {@link HttpHeadersUtils#UNKNOWN} for a {@code null} request; {@code LOCAL_IP}
     *         when the resolved address is the IPv6 loopback; otherwise the first usable
     *         header value (or the remote address) passed through Hutool's
     *         {@code getMultistageReverseProxyIp}, which picks an address from a
     *         comma-separated proxy chain
     */
    public static String getIpAddress(HttpServletRequest request) {
        if (Objects.isNull(request)) {
            return HttpHeadersUtils.UNKNOWN;
        }
        String ip = null;
        for (String header : CLIENT_IP_HEADERS) {
            ip = request.getHeader(header);
            if (!isAbsentIp(ip)) {
                break;
            }
        }
        if (isAbsentIp(ip)) {
            ip = request.getRemoteAddr();
        }
        if (IPV6_LOCAL_IP.equals(ip) || IPV6_LOCAL_IP_COMPRESSED.equals(ip)) {
            return LOCAL_IP;
        }
        return getMultistageReverseProxyIp(ip);
    }

    /** True when a header value carries no usable address: null, empty, or "unknown". */
    private static boolean isAbsentIp(String ip) {
        return ip == null || ip.isEmpty() || HttpHeadersUtils.UNKNOWN.equalsIgnoreCase(ip);
    }

    /**
     * Checks whether an address is private IPv4 (10/8, 172.16/12, 192.168/16) or equals
     * {@code LOCAL_IP}. Input that cannot be parsed as IPv4 (including {@code null}, the
     * empty string and IPv6 literals) is reported as internal.
     *
     * @param ip IP address text
     * @return {@code true} if the address is considered internal
     */
    public static boolean internalIp(String ip) {
        byte[] address = textToNumericFormatV4(ip);
        return internalIp(address) || LOCAL_IP.equals(ip);
    }

    /**
     * Checks the first two octets of an IPv4 address against the RFC 1918 private ranges.
     *
     * @param address IPv4 address bytes; {@code null} or fewer than 2 bytes counts as internal
     * @return {@code true} for 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16
     */
    private static boolean internalIp(byte[] address) {
        if (StringUtils.isNull(address) || address.length < 2) {
            return true;
        }
        // Compare as unsigned octets so the ranges read like dotted-decimal notation.
        final int b0 = address[0] & 0xFF;
        final int b1 = address[1] & 0xFF;
        // Every case returns: the previous version fell through from 172 into the 192
        // check, so 172.168.x.x was wrongly reported as private.
        switch (b0) {
            case 10:
                // 10.0.0.0/8
                return true;
            case 172:
                // 172.16.0.0/12
                return b1 >= 16 && b1 <= 31;
            case 192:
                // 192.168.0.0/16
                return b1 == 168;
            default:
                return false;
        }
    }

    /**
     * Parses an IPv4 address into its four bytes.
     * <p>
     * Accepts the shorthand forms {@code a} (one 32-bit value), {@code a.b} (8 + 24 bits),
     * {@code a.b.c} (8 + 8 + 16 bits) and {@code a.b.c.d} (four octets).
     *
     * @param text IPv4 address text; may be {@code null}
     * @return the 4 address bytes, or {@code null} if the text is null, empty or invalid
     */
    public static byte[] textToNumericFormatV4(String text) {
        if (text == null || text.isEmpty()) {
            return null;
        }
        byte[] bytes = new byte[4];
        // limit -1 keeps trailing empty parts, so "1.2.3." fails to parse instead of passing.
        String[] elements = text.split("\\.", -1);
        try {
            long l;
            int i;
            switch (elements.length) {
                case 1:
                    l = Long.parseLong(elements[0]);
                    if ((l < 0L) || (l > 4294967295L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l >> 24 & 0xFF);
                    bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 2:
                    l = Integer.parseInt(elements[0]);
                    if ((l < 0L) || (l > 255L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l & 0xFF);
                    l = Integer.parseInt(elements[1]);
                    if ((l < 0L) || (l > 16777215L)) {
                        return null;
                    }
                    bytes[1] = (byte) (int) (l >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 3:
                    for (i = 0; i < 2; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    l = Integer.parseInt(elements[2]);
                    if ((l < 0L) || (l > 65535L)) {
                        return null;
                    }
                    bytes[2] = (byte) (int) (l >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 4:
                    for (i = 0; i < 4; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    break;
                default:
                    return null;
            }
        } catch (NumberFormatException e) {
            return null;
        }
        return bytes;
    }
}
package org.example.rate;
/**
 * Scope of a rate limit: what a single Redis counter is shared across.
 */
public enum LimitType {
    /** One counter shared by every caller of the annotated method (the default). */
    DEFAULT,
    /** A separate counter per client IP address. */
    IP;
}
package org.example.rate;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
 * Enforces {@link RateLimiter} before the annotated method runs, using the fixed-window
 * Redis counter script on the classpath at {@code lua/limit.lua}.
 */
@Slf4j
@Aspect
@Component
public class RateLimiterAspect {

    /**
     * The counter script, built once. DefaultRedisScript caches the script's SHA1 per
     * instance, so creating a new one per request forced the classpath resource to be
     * read and hashed on every call. Instances are safe to share between threads.
     */
    private static final DefaultRedisScript<Long> LIMIT_SCRIPT = buildLimitScript();

    /** Creates the script definition; the resource itself is read lazily on first use. */
    private static DefaultRedisScript<Long> buildLimitScript() {
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        return script;
    }

    /**
     * Increments the caller's counter and rejects the call once it exceeds the limit.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation on the intercepted method
     * @throws ServiceException when the limit is exceeded or the script returns nothing
     * @throws RuntimeException when running the script fails (the cause is preserved)
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
        try {
            Long number = RedisUtils.execute(LIMIT_SCRIPT, keys, count, time);
            // Compare as long; intValue() would silently truncate very large counters.
            if (Objects.isNull(number) || number > count) {
                log.info("访问过于频繁,请稍候再试");
                throw new ServiceException("访问过于频繁,请稍候再试");
            }
            // Log the full counter key rather than only the annotation's prefix.
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number, combineKey);
        } catch (ServiceException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("服务器限流异常,请稍候再试", e);
        }
    }

    /**
     * Builds the Redis counter key:
     * {@code key()} + [client IP + "-" when {@link LimitType#IP}] + declaring class + "-" + method name.
     *
     * @param rateLimiter the annotation on the intercepted method
     * @param point       the intercepted join point
     * @return the counter key for this invocation
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuilder keyBuilder = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            keyBuilder.append(IpUtils.getIpAddress(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest()))
                    .append(StringUtils.TRANSVERSE);
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        keyBuilder.append(targetClass.getName()).append(StringUtils.TRANSVERSE).append(method.getName());
        return keyBuilder.toString();
    }
}
package org.example.rate;
import cn.hutool.core.lang.TypeReference;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.data.redis.core.script.RedisScript;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
 * Static facade over the Spring bean named {@value #REDIS_TEMPLATE_NAME}.
 * The template is looked up once, when this class is first initialised.
 */
public class RedisUtils {

    /** Bean name of the RedisTemplate fetched from the Spring context. */
    public static final String REDIS_TEMPLATE_NAME = "redisTemplate";

    private static final RedisTemplate<Object, Object> TEMPLATE = lookupTemplate();

    /** Fetches the template bean and narrows its raw type to {@code <Object, Object>}. */
    @SuppressWarnings("unchecked")
    private static RedisTemplate<Object, Object> lookupTemplate() {
        return (RedisTemplate<Object, Object>) SpringUtils.getBean(REDIS_TEMPLATE_NAME, RedisTemplate.class);
    }

    /**
     * Stores a value (Integer, String, entity, ...) under a key with no expiry.
     *
     * @param key   cache key
     * @param value value to store
     */
    public static <T> void setCacheObject(final String key, final T value) {
        ValueOperations<Object, Object> ops = TEMPLATE.opsForValue();
        ops.set(key, value);
    }

    /**
     * Stores a value under a key with an expiry.
     *
     * @param key      cache key
     * @param value    value to store
     * @param timeout  expiry amount; must not be {@code null} (it is unboxed)
     * @param timeUnit unit of {@code timeout}
     */
    public static <T> void setCacheObject(final String key, final T value, final Integer timeout, final TimeUnit timeUnit) {
        ValueOperations<Object, Object> ops = TEMPLATE.opsForValue();
        ops.set(key, value, timeout, timeUnit);
    }

    /**
     * Reads the value stored under a key.
     *
     * @param key cache key
     * @return the stored value cast to the caller's expected type, or {@code null} if absent
     */
    @SuppressWarnings("unchecked")
    public static <T> T getCacheObject(final String key) {
        Object cached = TEMPLATE.opsForValue().get(key);
        return (T) cached;
    }

    /**
     * Deletes a single key.
     *
     * @param key cache key
     * @return {@code true} only if Redis reports that the key was removed
     */
    public static boolean deleteObject(final String key) {
        Boolean removed = TEMPLATE.delete(key);
        return Boolean.TRUE.equals(removed);
    }

    /**
     * Executes a Lua script through the template.
     *
     * @param script the script definition
     * @param keys   values for KEYS
     * @param args   values for ARGV
     * @return the script result converted to the script's declared result type
     */
    public static <T> T execute(RedisScript<T> script, List<Object> keys, Object... args) {
        return TEMPLATE.execute(script, keys, args);
    }
}
package org.example.rate;
/**
 * Business exception carrying an optional numeric code and a replaceable message.
 */
public final class ServiceException extends RuntimeException {

    private static final long serialVersionUID = 1L;

    /** Optional error code; {@code null} when not supplied. */
    private Integer code;

    /** Message reported by {@link #getMessage()}; held here so it can be replaced later. */
    private String message;

    /** No-argument constructor kept so deserialisation works; code and message stay {@code null}. */
    public ServiceException() {
    }

    /**
     * @param message user-facing message; the code stays {@code null}
     */
    public ServiceException(String message) {
        this(null, message);
    }

    /**
     * @param code    error code, may be {@code null}
     * @param message user-facing message
     */
    public ServiceException(Integer code, String message) {
        this.code = code;
        this.message = message;
    }

    /** @return the error code, or {@code null} if none was given */
    public Integer getCode() {
        return code;
    }

    /** @return the current message (not the Throwable detail message, which is never set) */
    @Override
    public String getMessage() {
        return message;
    }

    /**
     * Replaces the message.
     *
     * @param message new message
     * @return this exception, for chaining
     */
    public ServiceException setMessage(String message) {
        this.message = message;
        return this;
    }
}
package org.example.rate;
import org.springframework.aop.framework.AopContext;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.lang.NonNull;
import org.springframework.stereotype.Component;
/**
 * Static access to the Spring bean factory and application context, for code that is
 * not itself a Spring bean (for example static initialisers).
 */
@Component
public class SpringUtils implements BeanFactoryPostProcessor, ApplicationContextAware {

    /** Captured in {@link #postProcessBeanFactory}; used for bean lookups. */
    private static ConfigurableListableBeanFactory beanFactory;

    /** Captured in {@link #setApplicationContext}; used for event publishing. */
    private static ApplicationContext applicationContext;

    /**
     * Stores the bean factory so the static lookup methods can use it.
     */
    @Override
    public void postProcessBeanFactory(@NonNull ConfigurableListableBeanFactory beanFactory) throws BeansException {
        SpringUtils.beanFactory = beanFactory;
    }

    /**
     * Stores the application context so {@link #publishEvent(Object)} can use it.
     */
    @Override
    public void setApplicationContext(@NonNull ApplicationContext applicationContext) throws BeansException {
        SpringUtils.applicationContext = applicationContext;
    }

    /**
     * Looks up a bean by name and type.
     *
     * @param name  bean name
     * @param clazz required type
     * @return the bean registered under {@code name}
     * @throws IllegalStateException if called before Spring has processed the bean factory
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        return requireBeanFactory().getBean(name, clazz);
    }

    /**
     * Looks up the single bean of a given type.
     *
     * @param clz required type
     * @return the bean of that type
     * @throws IllegalStateException if called before Spring has processed the bean factory
     */
    public static <T> T getBean(Class<T> clz) {
        return requireBeanFactory().getBean(clz);
    }

    /**
     * Publishes an application event.
     *
     * @param event the event
     * @throws IllegalStateException if called before the application context was set
     */
    public static void publishEvent(Object event) {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException("SpringUtils used before the ApplicationContext was set");
        }
        context.publishEvent(event);
    }

    /**
     * Returns the AOP proxy of the bean currently executing, converted to {@code t}.
     * Requires the proxy to be exposed (e.g. {@code @EnableAspectJAutoProxy(exposeProxy = true)});
     * otherwise {@link AopContext#currentProxy()} throws {@link IllegalStateException}.
     *
     * @param t   expected proxy type
     * @param <T> expected proxy type
     * @return the current proxy
     */
    public static <T> T getCurrentProxyBean(Class<T> t) {
        return ConvertUtils.convert(t, AopContext.currentProxy());
    }

    /** Returns the captured bean factory, failing with a clear message instead of an NPE. */
    private static ConfigurableListableBeanFactory requireBeanFactory() {
        ConfigurableListableBeanFactory factory = beanFactory;
        if (factory == null) {
            throw new IllegalStateException("SpringUtils used before the Spring bean factory was initialised");
        }
        return factory;
    }
}
package org.example.rate;
import cn.hutool.core.text.StrFormatter;
import org.springframework.util.AntPathMatcher;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
/**
 * String helpers extending Apache Commons Lang's {@code StringUtils} with separator
 * constants, null/empty checks for arrays and collections, Ant-style path matching,
 * placeholder formatting and camelCase to snake_case conversion.
 */
public class StringUtils extends org.apache.commons.lang3.StringUtils {
    /** Slash. */
    public static final String SLASH = "/";
    /** Dot. */
    public static final String SPOT = ".";
    /** Comma. */
    public static final String COMMA = ",";
    /** Asterisk. */
    public static final String ASTERISK = "*";
    /** Ampersand. */
    public static final String AMPERSAND = "&";
    /** Equals sign. */
    public static final String EQUAL = "=";
    /** Hyphen. */
    public static final String TRANSVERSE = "-";
    /** Underscore. */
    public static final String SEPARATOR = "_";
    /** Space. */
    public static final String SPACE = " ";
    /** Colon. */
    public static final String COLON = ":";

    /**
     * Shared matcher. AntPathMatcher is safe for concurrent {@code match} calls once
     * configured, so there is no need to allocate a new one (and lose its caches) per call.
     */
    private static final AntPathMatcher PATH_MATCHER = new AntPathMatcher();

    /**
     * Checks whether an object array is null or has no elements.
     *
     * @param objects the array to check
     * @return {@code true} if null or empty
     */
    public static boolean isEmpty(Object[] objects) {
        return isNull(objects) || (objects.length == 0);
    }

    /**
     * Checks whether an object is null.
     *
     * @param object the object to check
     * @return {@code true} if null
     */
    public static boolean isNull(Object object) {
        return object == null;
    }

    /**
     * Checks whether a collection (List, Set, Queue, ...) is null or empty.
     *
     * @param coll the collection to check
     * @return {@code true} if null or empty
     */
    public static boolean isEmpty(Collection<?> coll) {
        return Objects.isNull(coll) || coll.isEmpty();
    }

    /**
     * Checks whether a string matches any of the given Ant-style patterns.
     *
     * @param str        the string to test
     * @param characters the patterns to try
     * @return {@code true} if at least one pattern matches; {@code false} if either input is empty
     */
    public static boolean matches(String str, List<String> characters) {
        if (isEmpty(str) || isEmpty(characters)) {
            return false;
        }
        for (String pattern : characters) {
            if (isMatch(pattern, str)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Matches a URL against an Ant-style pattern:
     * {@code ?} matches one character; {@code *} matches any characters within one path
     * segment; {@code **} matches any number of path segments.
     *
     * @param pattern the pattern
     * @param url     the URL to test
     * @return {@code true} if the URL matches the pattern
     */
    public static boolean isMatch(String pattern, String url) {
        return PATH_MATCHER.match(pattern, url);
    }

    /**
     * Formats text by replacing {@code {}} placeholders with the parameters, in order.
     * To emit a literal {@code {}}, escape the brace as {@code \\{}; to emit a backslash
     * before a placeholder, use {@code \\\\}. Examples:
     * <ul>
     *   <li>{@code format("this is {} for {}", "a", "b")} gives {@code this is a for b}</li>
     *   <li>{@code format("this is \\{} for {}", "a", "b")} gives {@code this is \{} for a}</li>
     *   <li>{@code format("this is \\\\{} for {}", "a", "b")} gives {@code this is \a for b}</li>
     * </ul>
     *
     * @param template text template with {@code {}} placeholders
     * @param params   parameter values
     * @return the formatted text, or {@code template} unchanged if it or {@code params} is empty
     */
    public static String format(String template, Object... params) {
        if (isEmpty(params) || isEmpty(template)) {
            return template;
        }
        return StrFormatter.format(template, params);
    }

    /**
     * Converts camelCase to lower-case snake_case, keeping acronyms together
     * (e.g. {@code "userName"} becomes {@code "user_name"},
     * {@code "HTMLParser"} becomes {@code "html_parser"}).
     *
     * @param str the input; may be {@code null}
     * @return the converted string, or {@code null} for {@code null} input
     */
    public static String toUnderScoreCase(String str) {
        if (Objects.isNull(str)) {
            return null;
        }
        StringBuilder sb = new StringBuilder();
        // Whether the previous character is upper case
        boolean preCharIsUpperCase;
        // Whether the current character is upper case
        boolean currentCharIsUpperCase;
        // Whether the next character is upper case. For the last character this keeps
        // the value from the previous iteration, i.e. whether the last character itself
        // is upper case.
        boolean nextCharIsUpperCase = true;
        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);
            if (i > 0) {
                preCharIsUpperCase = Character.isUpperCase(str.charAt(i - 1));
            } else {
                preCharIsUpperCase = false;
            }
            currentCharIsUpperCase = Character.isUpperCase(c);
            if (i < (str.length() - 1)) {
                nextCharIsUpperCase = Character.isUpperCase(str.charAt(i + 1));
            }
            if (preCharIsUpperCase && currentCharIsUpperCase && !nextCharIsUpperCase) {
                // End of an acronym followed by a lower-case word: "HTMLParser" -> "html_parser"
                sb.append(SEPARATOR);
            } else if ((i != 0 && !preCharIsUpperCase) && currentCharIsUpperCase) {
                // Start of a new word after a lower-case character: "userName" -> "user_name"
                sb.append(SEPARATOR);
            }
            sb.append(Character.toLowerCase(c));
        }
        return sb.toString();
    }
}
最后的效果