SpringBoot整合Redis + Lua脚本实现限流
1. 引入依赖
注意:SpringBoot版本为2.2.6.RELEASE
<dependencies>
<!-- Spring Boot web starter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Redis starter (Spring Data Redis) -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- AOP starter, needed for the @Aspect based limiter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Guava (ImmutableList is used to build the Redis key list) -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>21.0</version>
</dependency>
<!-- Apache Commons Lang (StringUtils); no version given here, presumably managed by the Spring Boot parent POM (confirm) -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<!-- Lombok (@Slf4j); no version given here, presumably managed by the Spring Boot parent POM (confirm) -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
</dependencies>
2. 配置Redis
# Redis connection settings; Spring Boot reads them as spring.redis.host / spring.redis.port,
# so host and port must be nested under spring -> redis.
spring:
  redis:
    host: 127.0.0.1
    port: 6379
3. 定义RedisTemplate
/**
 * Spring configuration that registers the {@code RedisTemplate} bean used by the rate limiter.
 */
@Configuration
public class RedisLimiterHelper {

    /**
     * Creates a template in which keys, values, hash keys and hash values are all
     * written with the same Jackson-based JSON serializer.
     *
     * @param redisConnectionFactory Lettuce connection factory the template connects through
     * @return the fully initialized template
     */
    @Bean
    public RedisTemplate<String, Serializable> RedisTemplate(LettuceConnectionFactory redisConnectionFactory) {
        // Prepare the single JSON serializer that is shared by every serializer slot below.
        Jackson2JsonRedisSerializer<Object> jsonSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        jsonSerializer.setObjectMapper(new ObjectMapper());

        RedisTemplate<String, Serializable> limiterTemplate = new RedisTemplate<>();
        limiterTemplate.setConnectionFactory(redisConnectionFactory);
        limiterTemplate.setDefaultSerializer(jsonSerializer);
        limiterTemplate.setKeySerializer(jsonSerializer);
        limiterTemplate.setValueSerializer(jsonSerializer);
        limiterTemplate.setHashKeySerializer(jsonSerializer);
        limiterTemplate.setHashValueSerializer(jsonSerializer);
        limiterTemplate.afterPropertiesSet();
        return limiterTemplate;
    }
}
4. 定义限流枚举类
/**
 * Strategies for choosing the key that a rate limit is counted against.
 */
public enum LimitType {

    /** Count against a caller-supplied, custom key. */
    CUSTOMER,

    /** Count against the IP address of the request. */
    IP;
}
5. 定义限流注解
/**
 * Declares a rate limit: at most {@link #count()} calls within {@link #period()} seconds.
 *
 * <p>Retained at runtime so it can be read reflectively.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface Limit {

    /**
     * Human-readable name of the limited resource.
     */
    String name() default "";

    /**
     * Key the counter is stored under.
     */
    String key() default "";

    /**
     * Prefix placed in front of the key.
     */
    String prefix() default "";

    /**
     * Length of the time window, in seconds.
     */
    int period();

    /**
     * Maximum number of calls allowed within one time window.
     */
    int count();

    /**
     * How the key is chosen; defaults to {@link LimitType#CUSTOMER}.
     */
    LimitType limitType() default LimitType.CUSTOMER;
}
6. 定义限流拦截器
/**
 * AOP aspect that rate-limits methods annotated with {@link Limit} using a counter kept in Redis.
 *
 * <p>Every intercepted call runs a Lua script that increments a per-key counter and, on the
 * first increment, gives the key a time-to-live of {@link Limit#period()} seconds. The target
 * method is invoked only while the counter does not exceed {@link Limit#count()}.
 */
@Aspect
@Configuration
@Slf4j
public class LimitInterceptor {

    /** Header value that is treated as "no address available" when resolving the client IP. */
    public static final String UNKNOW_KEY = "unknown";

    /** Built once and shared, so the script text is not reassembled on every intercepted call. */
    private static final RedisScript<Number> LIMIT_SCRIPT =
            new DefaultRedisScript<>(buildLuaScript(), Number.class);

    private final RedisTemplate<String, Serializable> redisTemplate;

    /**
     * @param limitRedisTemplate template used to execute the limiter script
     */
    @Autowired
    public LimitInterceptor(RedisTemplate<String, Serializable> limitRedisTemplate) {
        this.redisTemplate = limitRedisTemplate;
    }

    /**
     * Counts the call against its limit and proceeds with the target method when allowed.
     *
     * @param point the intercepted invocation of a method annotated with {@link Limit}
     * @return whatever the target method returns
     * @throws RuntimeException when the limit is exceeded or no count could be obtained; runtime
     *         exceptions raised by Redis or by the target method are rethrown unchanged, and
     *         checked throwables from the target are wrapped with their cause preserved
     */
    @Around("execution(public * *(..)) && @annotation(com.hmds.redisdemo.config.Limit)")
    public Object interceptor(ProceedingJoinPoint point) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Limit limit = method.getAnnotation(Limit.class);
        String name = limit.name();
        int limitPeriod = limit.period();
        int limitCount = limit.count();
        String key = resolveKey(limit, method);
        ImmutableList<String> keys = ImmutableList.of(StringUtils.join(limit.prefix(), key));
        try {
            Number count = redisTemplate.execute(LIMIT_SCRIPT, keys, limitCount, limitPeriod);
            log.info("Access try count:{}, name:{}, key:{}", count, name, key);
            // A missing count is treated like an exceeded limit: the call is rejected.
            if (count == null || count.intValue() > limitCount) {
                throw new RuntimeException("You have been dragged into the blacklist");
            }
            return point.proceed();
        } catch (RuntimeException e) {
            // Rethrow unchanged so callers see the original exception type, not a re-wrapped one.
            log.error("LimitInterceptor error", e);
            throw e;
        } catch (Throwable e) {
            // The method signature declares no checked exceptions, so wrap but keep the cause.
            log.error("LimitInterceptor error", e);
            throw new RuntimeException("server exception", e);
        }
    }

    /**
     * Chooses the counter key for the given limit type. For {@link LimitType#IP} the
     * annotation's {@code key()} is not used; the client address is the key.
     */
    private String resolveKey(Limit limit, Method method) {
        switch (limit.limitType()) {
            case IP:
                return getIpAddr();
            case CUSTOMER:
                return limit.key();
            default:
                return StringUtils.upperCase(method.getName());
        }
    }

    /**
     * Assembles the Lua limiter script. KEYS[1] is the counter key, ARGV[1] the allowed count
     * and ARGV[2] the window length in seconds.
     *
     * @return the script source
     */
    private static String buildLuaScript() {
        return String.join("\n",
                "local c",
                "c = redis.call('get',KEYS[1])",
                // Limit already exceeded: report the current count without incrementing further.
                "if c and tonumber(c) > tonumber(ARGV[1]) then",
                // GET yields a string; convert it so the script always replies with an integer.
                "return tonumber(c);",
                "end",
                // Otherwise count this call.
                "c = redis.call('incr',KEYS[1])",
                "if tonumber(c) == 1 then",
                // First call of a window: start the expiry so the counter resets afterwards.
                "redis.call('expire',KEYS[1],ARGV[2])",
                "end",
                "return c;");
    }

    /**
     * Resolves the client IP of the current request, preferring proxy headers over the
     * socket address.
     *
     * <p>NOTE(review): these headers are supplied by the client unless a trusted proxy
     * overwrites them, so a caller could vary them to obtain a fresh counter.
     *
     * @return the first usable address found
     * @throws IllegalStateException when no servlet request is bound to the current thread
     */
    public String getIpAddr() {
        ServletRequestAttributes attributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            throw new IllegalStateException("No current HTTP request; cannot resolve client IP");
        }
        HttpServletRequest request = attributes.getRequest();
        // x-forwarded-for may carry a comma-separated chain; only the first entry is used.
        String ip = StringUtils.trim(StringUtils.substringBefore(request.getHeader("x-forwarded-for"), ","));
        if (isUnknown(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /** Tells whether a header value carries no usable address. */
    private static boolean isUnknown(String ip) {
        return StringUtils.isBlank(ip) || UNKNOW_KEY.equalsIgnoreCase(ip);
    }
}
7. 接口
/**
 * Demo endpoints for trying out {@link Limit}: each one allows 3 calls per 10 seconds and
 * returns how many times it has been served.
 */
@RestController
@Slf4j
public class LimiterController {

    private static final AtomicInteger TEST_1 = new AtomicInteger();
    private static final AtomicInteger TEST_2 = new AtomicInteger();
    private static final AtomicInteger TEST_3 = new AtomicInteger();

    /**
     * Limited under the key {@code "test"} with the default limit type.
     *
     * @return the number of served calls to this endpoint
     */
    @Limit(key = "test", period = 10, count = 3)
    @GetMapping("/test1")
    public int test1() {
        return TEST_1.incrementAndGet();
    }

    /**
     * Limited under the explicitly custom key {@code "customer_test"}.
     *
     * @return the number of served calls to this endpoint
     */
    @Limit(key = "customer_test", period = 10, count = 3, limitType = LimitType.CUSTOMER)
    @GetMapping("/test2")
    public int test2() {
        return TEST_2.incrementAndGet();
    }

    /**
     * Limited by request IP.
     *
     * @return the number of served calls to this endpoint
     */
    @Limit(key = "ip_test", period = 10, count = 3, limitType = LimitType.IP)
    @GetMapping("/test3")
    public int test3() {
        return TEST_3.incrementAndGet();
    }
}
8. 测试
每个接口在10秒内最多允许访问3次;超过3次后,请求会被拒绝并抛出异常(异常信息包含 "You have been dragged into the blacklist")。