redis执行lua脚本实现springboot接口限流实践
1.项目依赖
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<artifactId>springboot-rabbitmq</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.5.2</version>
</parent>
<dependencies>
<!-- Springboot核心jar包 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<!-- web开发包:包含Tomcat和Springmvc-->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-amqp</artifactId>
</dependency>
<!-- actuator -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-actuator</artifactId>
</dependency>
<!--logback日志包-->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-logging</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.16.22</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.24</version>
</dependency>
<!-- Radis -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- redis依赖commons-pool 这个依赖一定要添加 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
</dependency>
<!-- mybatis plus 代码生成器 -->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-boot-starter</artifactId>
<version>3.2.0</version>
</dependency>
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-generator</artifactId>
<version>3.2.0</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.redisson</groupId>
<artifactId>redisson</artifactId>
<version>3.6.5</version>
</dependency>
<!--JWT 依赖-->
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt</artifactId>
<version>0.9.0</version>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.7.22</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
<groupId>org.flowable</groupId>
<artifactId>flowable-engine</artifactId>
<version>6.7.2</version>
</dependency>
<!-- Flowable 任务持久化模块 -->
<dependency>
<groupId>org.flowable</groupId>
<artifactId>flowable-spring-boot-starter</artifactId>
<version>6.7.2</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
<version>2.4.3</version>
</plugin>
</plugins>
</build>
</project>
2.RateLimiterAspect.java
package com.springboot.aospect;
import com.alibaba.fastjson.JSONObject;
import com.springboot.diyanotion.RateLimiter;
import com.springboot.enums.LimitType;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* spingAop限流处理
*/
@Aspect
@Component
public class RateLimiterAspect {
private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);
private RedisTemplate<String, Object> redisTemplate;
private RedisScript<Long> limitScript;
@Autowired
public void setRedisTemplate1(RedisTemplate<String, Object> redisTemplate) {
this.redisTemplate = redisTemplate;
}
@Autowired
public void setLimitScript(RedisScript<Long> limitScript) {
this.limitScript = limitScript;
}
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
int time = rateLimiter.time();
int count = rateLimiter.count();
//根据类型生成一个键值, ip 或者 类路径+方法名
String combineKey = getCombineKey(rateLimiter, point);
//Collections.singletonList被限定只被分配一个内存空间,也就是只能存放一个元素的内容。
List<String> keys = Collections.singletonList(combineKey);
// 层是redis执行lua脚本, 键值存储数量,redis过期时间是限流时间
Long number = redisTemplate.execute(limitScript, keys, count, time);
if (null == number || number.intValue() > count) {
HttpServletResponse resp = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getResponse();
JSONObject jsonObject = new JSONObject();
jsonObject.put("success", false);
jsonObject.put("msg", "请求过频繁,请稍后再试......");
try {
assert resp != null;
output(resp, jsonObject.toJSONString());
} catch (Exception e) {
log.error("系统异常error,e:{}", e.toString());
}
}
log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), combineKey);
}
public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
if (rateLimiter.limitType() == LimitType.IP) {
stringBuffer.append(getIpAddr(getRequest())).append("-");
}
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
return stringBuffer.toString();
}
/**
* 获取request
*/
private HttpServletRequest getRequest() {
return getRequestAttributes().getRequest();
}
private ServletRequestAttributes getRequestAttributes() {
RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
return (ServletRequestAttributes) attributes;
}
/**
* 获取客户端IP
*
* @param request 请求对象
* @return IP地址
*/
private String getIpAddr(HttpServletRequest request) {
if (request == null) {
return "unknown";
}
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("X-Forwarded-For");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("X-Real-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : getMultistageReverseProxyIp(ip);
}
/**
* 从多级反向代理中获得第一个非unknown IP地址
*
* @param ip 获得的IP地址
* @return 第一个非unknown IP地址
*/
private String getMultistageReverseProxyIp(String ip) {
// 多级反向代理检测
if (ip != null && ip.indexOf(",") > 0) {
final String[] ips = ip.trim().split(",");
for (String subIp : ips) {
if (false == isUnknown(subIp)) {
ip = subIp;
break;
}
}
}
return ip;
}
/**
* 检测给定字符串是否为未知,多用于检测HTTP请求相关
*
* @param checkString 被检测的字符串
* @return 是否未知
*/
public static boolean isUnknown(String checkString) {
return "".equals(checkString) || null == checkString || "unknown".equalsIgnoreCase(checkString);
}
public void output(HttpServletResponse response, String msg) throws IOException {
response.setContentType("application/json;charset=UTF-8");
ServletOutputStream outputStream = null;
try {
outputStream = response.getOutputStream();
outputStream.write(msg.getBytes(StandardCharsets.UTF_8));
} catch (IOException e) {
e.printStackTrace();
} finally {
outputStream.flush();
outputStream.close();
}
}
}
3.redis配置文件
package com.springboot.config;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import org.redisson.Redisson;
import org.redisson.config.Config;
import org.springframework.cache.annotation.CachingConfigurerSupport;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;
@Configuration
@EnableCaching
public class RedisConfig extends CachingConfigurerSupport {
@Bean
@SuppressWarnings(value = {"unchecked", "rawtypes"})
public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
RedisTemplate<String, Object> template = new RedisTemplate<>();
template.setConnectionFactory(connectionFactory);
Jackson2JsonRedisSerializer serializer = new Jackson2JsonRedisSerializer(Object.class);
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance,ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.EXISTING_PROPERTY);
// 使用StringRedisSerializer来序列化和反序列化redis的key值
template.setKeySerializer(new StringRedisSerializer());
template.setValueSerializer(serializer);
// Hash的key也采用StringRedisSerializer的序列化方式
template.setHashKeySerializer(new StringRedisSerializer());
template.setHashValueSerializer(serializer);
template.afterPropertiesSet();
return template;
}
@Bean
public DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
@Bean
public Redisson redisson() {
Config config = new Config();
config.useSingleServer().setAddress("redis://localhost:6379").setPassword("123456").setDatabase(0);
return (Redisson) Redisson.create(config);
}
}
4.自定义注解
package com.springboot.diyanotion;
import com.springboot.enums.LimitType;
import java.lang.annotation.*;
/**
* 限流注解
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter
{
/**
* 限流key
*/
public String key() default "rate_limit";
/**
* 限流时间,单位秒
*/
public int time() default 10;
/**
* 限流次数
*/
public int count() default 5;
/**
* 限流类型
*/
public LimitType limitType() default LimitType.DEFAULT;
}
5.controller
package com.springboot.controller;
import com.springboot.diyanotion.RateLimiter;
import com.springboot.entity.UserInfo;
import com.springboot.enums.LimitType;
import com.springboot.service.LoginService;
import com.springboot.util.Result;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
import java.util.Map;
@RestController
public class LoginController {
/**
* lua脚本测试接口限流
* 5秒内同一IP只能访问3次
*/
@GetMapping("/apiLimit")
@RateLimiter(key = "apiLimit", count = 3, time = 5, limitType = LimitType.IP)
public Result apiLimit() {
Map<String, Object> map = new HashMap<>();
map.put("test","测试redis限流");
return Result.success(map);
}
}
6.lua脚本(在resource目录下新建lua文件夹,新建limit.lua文件),内容如下
local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key);
if current and tonumber(current) > count then
return tonumber(current);
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current);
7.测试结果