基于Spring Aop + Redis实现分布式多维度前置后置限流
说明
在实际场景中,比如发送短信验证码、刷评论等,都需要进行一定的限流控制。限流又可以分为前置限流和后置限流。
所谓前置限流,即在调用目标接口之前进行校验并计数,无论被调用的接口是否发生异常、是否返回预期值,都会计入调用次数;
后置限流则是在调用接口之后,根据指定的condition判断是否计入调用次数,condition支持SpEL表达式。
本文通过Spring AOP + 自定义注解 + Redis分布式锁 + Redis Lua脚本实现前置、后置限流,并且提供用户维度、IP维度和自定义SpEL表达式Key的多维度限流。
配置Maven依赖
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
配置Redis
配置文件application.yml
spring:
  redis:
    host: 192.168.0.1
    port: 6379
    password:
    lettuce:
      pool:
        time-between-eviction-runs: 30s
        max-active: 10
        max-wait: -1ms
        min-idle: 0
        max-idle: 8
其中time-between-eviction-runs表示每隔多长时间清理一次Redis连接池中的空闲连接(清理到min-idle配置的数量)。最好配置该项,否则使用Lettuce时经常会出现“远程主机强制关闭了一个现有的连接”的错误。注意:以上配置需位于spring节点下;Spring Boot 3.x中对应的前缀为spring.data.redis。
Configuration配置类
@Slf4j
@EnableCaching
@ConditionalOnClass(RedisOperations.class)
public class RedisConfiguration extends CachingConfigurerSupport {
/**
* RedisTemplate<String, Object> 走的是RedisTemplate
* RedisTemplate<String, String> 走的是StringRedisTemplate
*/
@Bean
@Primary
public RedisTemplate redisTemplate(RedisConnectionFactory connectionFactory) {
RedisTemplate<Object, Object> template = new RedisTemplate<>();
template.setConnectionFactory(connectionFactory);
// 指定序列化输入的类型,保证反序列化出来一个java对象
Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
//objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
//防止对象存入后,解析出来时是各Map
objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.WRAPPER_ARRAY);
serializer.setObjectMapper(objectMapper);
RedisSerializer<String> stringRedisSerializer = new StringRedisSerializer();
template.setKeySerializer(stringRedisSerializer);
template.setHashKeySerializer(stringRedisSerializer);
template.setValueSerializer(serializer);
template.setHashValueSerializer(serializer);
template.afterPropertiesSet();
return template;
}
@Bean
public RateLimitAspect rateLimitAspect(RedisTemplate<String, Object> redisTemplate, RedisLockService redisLock) {
return new RateLimitAspect(redisTemplate, redisLock);
}
@Bean
public RedisLockService redisLockService(RedisTemplate<String, Long> redisTemplate) {
return new RedisLockService(redisTemplate);
}
@Bean
public RedisLockAspect redisLockAspect(RedisLockService redisLockService) {
return new RedisLockAspect(redisLockService);
}
}
定义限流注解@KeyRateLimiter和@PostKeyRateLimiter
首先创建枚举类型RateLimiterType,定义支持哪几种限流维度
/**
 * Dimension by which a rate limiter partitions its Redis counters.
 */
public enum RateLimiterType {
    /** One counter per client IP address (resolved from the current HTTP request). */
    CLIENT_IP,
    /** One counter per user (taken from the current security principal). */
    USER,
    /** Counter identified only by the annotation's {@code key}, which must therefore be non-empty. */
    CUSTOM
}
前置限流注解KeyRateLimiter:其中key支持SpEL表达式解析,可以获取目标方法的参数作为Key值;type用于指定使用哪种限流维度;limit表示在interval个intervalUnit组成的时间窗口内允许的调用次数(只有使用默认值时才等同于“每秒次数”)。
/**
 * Pre-invocation rate limit. The counter is checked and, if still below the limit, incremented
 * before the annotated method runs. Such an attempt therefore counts whether or not the method
 * later throws or returns the expected value.
 *
 * <p>Repeatable: all limiters on a method are checked, and the message of every exceeded limiter
 * is reported.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(KeyRateLimiters.class)
public @interface KeyRateLimiter {
    /**
     * Rate-limit key. A value containing {@code #} is evaluated as a Spring EL expression in which
     * the method arguments are available as variables by parameter name (e.g. {@code #username}).
     * Required when {@link #type()} is {@link RateLimiterType#CUSTOM}.
     *
     * @return the key or key expression
     */
    String key() default "";

    /**
     * Maximum number of calls allowed within one window of {@link #interval()} {@link #intervalUnit()}.
     * This is a per-window count, not a per-second rate; the two only coincide with the defaults.
     *
     * @return the number of calls allowed per window
     */
    int limit() default 1;

    /**
     * Window length, expressed in {@link #intervalUnit()}; defaults to 1.
     *
     * @return the window length
     */
    int interval() default 1;

    /**
     * Unit of {@link #interval()}; defaults to seconds.
     *
     * @return the window unit
     */
    TimeUnit intervalUnit() default TimeUnit.SECONDS;

    /**
     * Dimension the counter is partitioned by. With {@link RateLimiterType#CUSTOM} (the default),
     * a non-empty {@link #key()} is required.
     *
     * @return the limiter dimension
     */
    RateLimiterType type() default RateLimiterType.CUSTOM;

    /**
     * Message reported when the limit is exceeded.
     *
     * @return the rejection message
     */
    String message() default "您的操作过快,请稍后再试!";
}
后置限流注解PostKeyRateLimiter:与KeyRateLimiter不同的是增加了condition,根据SpEL表达式的布尔返回值判断是否计入调用次数。另外,PostKeyRateLimiter的实现方式也不一样:由于调用计数发生在方法执行完成之后,所以需要结合Redis分布式锁来串行化调用,性能自然会比KeyRateLimiter差一些。两者都使用Redis pipeline批量执行命令,并且可以在同一个方法上叠加配置。
/**
 * Post-invocation rate limit. Before the call, the counter is only read, and the call is rejected
 * if the counter has already reached {@link #limit()}. After the method returns normally, the call
 * is counted only if {@link #condition()} holds.
 *
 * <p>The check-invoke-count sequence runs under a Redis distributed lock, which makes it slower than
 * {@link KeyRateLimiter}. Repeatable, and may be combined with {@link KeyRateLimiter} on the same method.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(PostKeyRateLimiters.class)
public @interface PostKeyRateLimiter {
    /**
     * Rate-limit key. A value containing {@code #} is evaluated as a Spring EL expression in which
     * the method arguments are available as variables by parameter name (e.g. {@code #username}).
     * Required when {@link #type()} is {@link RateLimiterType#CUSTOM}.
     *
     * @return the key or key expression
     */
    String key() default "";

    /**
     * Maximum number of counted calls allowed within one window of {@link #interval()} {@link #intervalUnit()}.
     * This is a per-window count, not a per-second rate; the two only coincide with the defaults.
     *
     * @return the number of calls allowed per window
     */
    int limit() default 1;

    /**
     * Window length, expressed in {@link #intervalUnit()}; defaults to 1.
     *
     * @return the window length
     */
    int interval() default 1;

    /**
     * Unit of {@link #interval()}; defaults to seconds.
     *
     * @return the window unit
     */
    TimeUnit intervalUnit() default TimeUnit.SECONDS;

    /**
     * Dimension the counter is partitioned by. With {@link RateLimiterType#CUSTOM} (the default),
     * a non-empty {@link #key()} is required.
     *
     * @return the limiter dimension
     */
    RateLimiterType type() default RateLimiterType.CUSTOM;

    /**
     * Spring EL condition deciding whether a completed call is counted. The method's return value is
     * available as {@code #rtv}, e.g. {@code #rtv.code == 200}. Every call is counted when the
     * condition is blank, contains no {@code #}, or evaluates to {@code null}.
     *
     * @return the counting condition
     */
    String condition() default "";

    /**
     * Message reported when the limit is exceeded.
     *
     * @return the rejection message
     */
    String message() default "您的操作过快,请稍后再试!";
}
再定义两个容器注解,以支持在同一个方法上同时使用多个限流注解
/**
 * Container for repeated {@link KeyRateLimiter} annotations. The compiler wraps the limiters in
 * this container when {@code @KeyRateLimiter} is declared more than once on a method.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface KeyRateLimiters {
    /** The individual pre-invocation limiters declared on the method. */
    KeyRateLimiter[] value();
}
/**
 * Container for repeated {@link PostKeyRateLimiter} annotations. The compiler wraps the limiters in
 * this container when {@code @PostKeyRateLimiter} is declared more than once on a method.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface PostKeyRateLimiters {
    /** The individual post-invocation limiters declared on the method. */
    PostKeyRateLimiter[] value();
}
创建配置类RateLimitConfig,上面的PostKeyRateLimiter和KeyRateLimiter最终都会转换为RateLimitConfig实例
/**
 * Mutable, per-invocation view of a {@link KeyRateLimiter} or {@link PostKeyRateLimiter}.
 * The aspect later overwrites {@link #key} with the concrete Redis key.
 */
@Data
public class RateLimitConfig {
    /** Counter key: the annotation's raw key at first, replaced by the resolved Redis key before use. */
    private String key;
    /** Maximum number of calls allowed within one window. */
    private int limit;
    /** Window length, in {@link #intervalUnit} units. */
    private int rateInterval;
    /** Unit of {@link #rateInterval}. */
    private TimeUnit intervalUnit;
    /** SpEL counting condition; only set for post-invocation limiters, {@code null} otherwise. */
    private String condition;
    /** Dimension used to partition counters. */
    private RateLimiterType type;
    /** Message reported when the limit is exceeded. */
    private String message;

    /** Copies every attribute of a post-invocation limiter, including its condition. */
    public RateLimitConfig(PostKeyRateLimiter annotation) {
        this(annotation.key(), annotation.limit(), annotation.interval(), annotation.intervalUnit(),
                annotation.condition(), annotation.type(), annotation.message());
    }

    /** Copies a pre-invocation limiter; it has no condition, so {@code condition} stays {@code null}. */
    public RateLimitConfig(KeyRateLimiter annotation) {
        this(annotation.key(), annotation.limit(), annotation.interval(), annotation.intervalUnit(),
                null, annotation.type(), annotation.message());
    }

    private RateLimitConfig(String key, int limit, int rateInterval, TimeUnit intervalUnit,
                            String condition, RateLimiterType type, String message) {
        this.key = key;
        this.limit = limit;
        this.rateInterval = rateInterval;
        this.intervalUnit = intervalUnit;
        this.condition = condition;
        this.type = type;
        this.message = message;
    }
}
创建限流切面RateLimitAspect
/**
 * Enforces {@link KeyRateLimiter} (pre-invocation) and {@link PostKeyRateLimiter} (post-invocation)
 * limits with fixed-window counters in Redis.
 *
 * <p>Front limiters consume a token before the target method runs. Post limiters are read before
 * the call and counted only after the call returns normally and its condition matches. That
 * read-call-count sequence runs under a Redis lock.
 */
@Slf4j
@Aspect
@RequiredArgsConstructor
public class RateLimitAspect extends AbstractAspect {

    /** Key prefix for pre-invocation counters. */
    private static final String FRONT_KEY_PREFIX = "rt:front:";
    /** Key prefix for post-invocation counters. */
    private static final String POST_KEY_PREFIX = "rt:post:";

    /**
     * Fixed-window counter. KEYS[1] is the counter key, ARGV[1] the limit, and ARGV[2] the window in ms.
     * The script returns 0 when the counter has already reached the limit. Otherwise it increments the
     * counter, starts the window expiry on the first hit, and returns 1.
     */
    private static final RedisScript<Number> RATE_LUA_SCRIPT = new DefaultRedisScript<>(
            "local current = tonumber(redis.call('get',KEYS[1]) or '0')\n" +
            "if current >= tonumber(ARGV[1]) then\n" +
            "\treturn 0\n" +
            "end\n" +
            "current = redis.call('incr',KEYS[1])\n" +
            "if current == 1 then\n" +
            "\tredis.call('pexpire',KEYS[1],ARGV[2])\n" +
            "end\n" +
            "return 1",
            Number.class);

    private final RedisTemplate<String, Object> redisTemplate;
    private final RedisLockService redisLock;

    /**
     * Matches methods carrying any (single or repeated) front or post rate-limit annotation.
     */
    @Pointcut("@annotation(com.iwork.boot.redis.rt.KeyRateLimiter) " +
            "|| @annotation(com.iwork.boot.redis.rt.KeyRateLimiters) " +
            "|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiter) " +
            "|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiters)")
    public void frontRateLimiter() {
    }

    /**
     * Applies the front limiters, then either proceeds directly or, when post limiters exist,
     * runs the post-limited invocation under a Redis lock.
     *
     * @param joinPoint the intercepted call
     * @return the target method's return value
     * @throws Throwable whatever the target method throws. A {@code BusinessException} is thrown when
     *                   a limit is exceeded.
     */
    @Around("frontRateLimiter()")
    public Object executeFront(ProceedingJoinPoint joinPoint) throws Throwable {
        Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        // getAnnotationsByType returns both a single annotation and those inside the repeatable container.
        List<RateLimitConfig> limitConfigs = Stream.of(method.getAnnotationsByType(KeyRateLimiter.class))
                .map(RateLimitConfig::new)
                .collect(Collectors.toList());
        List<RateLimitConfig> postLimitConfigs = Stream.of(method.getAnnotationsByType(PostKeyRateLimiter.class))
                .map(RateLimitConfig::new)
                .collect(Collectors.toList());

        // Front check: consume tokens up-front, regardless of how the call ends.
        setKey(FRONT_KEY_PREFIX, joinPoint, limitConfigs);
        Set<String> errMsgSet = tryAcquire(limitConfigs.toArray(new RateLimitConfig[0]));
        if (!errMsgSet.isEmpty()) {
            // Ideally a dedicated exception type handled by a global exception handler.
            throw new BusinessException(errMsgSet.toString());
        }
        if (postLimitConfigs.isEmpty()) {
            return joinPoint.proceed();
        }

        // Post check: read counters, invoke, then count, all under a lock.
        setKey(POST_KEY_PREFIX, joinPoint, postLimitConfigs);
        // NOTE(review): only the first post limiter's key is locked. Concurrent calls that share another
        // post key but differ in the first one are not serialized, so that limit can be briefly exceeded.
        String lockKey = "locks:" + postLimitConfigs.get(0).getKey();
        // Lock arguments (10, 60, SECONDS) are passed through unchanged; see RedisLockService for their meaning.
        return redisLock.executeWithLock(lockKey, 10, 60, TimeUnit.SECONDS,
                () -> checkProceedAndCount(joinPoint, postLimitConfigs));
    }

    /**
     * Rejects the call if any post counter is exhausted, invokes the target, and then counts the call
     * against every post limiter whose condition matches the return value.
     */
    private Object checkProceedAndCount(ProceedingJoinPoint joinPoint, List<RateLimitConfig> postLimitConfigs) {
        Set<String> errMsgSet = findExhausted(postLimitConfigs);
        if (!errMsgSet.isEmpty()) {
            // Ideally a dedicated exception type handled by a global exception handler.
            throw new BusinessException(errMsgSet.toString());
        }
        Object result = proceedUnchecked(joinPoint);
        RateLimitConfig[] counted = postLimitConfigs.stream()
                .filter(config -> parsePostSpEl(result, config))
                .toArray(RateLimitConfig[]::new);
        // The call has already happened, so the acquire result is intentionally ignored.
        tryAcquire(counted);
        return result;
    }

    /**
     * Invokes the target method. Unchecked exceptions and errors propagate unchanged, just as they do for
     * methods without post limiters. Checked exceptions are wrapped in {@code BusinessException}, as
     * before, so they can pass through the lock callback (assumed not to declare checked exceptions).
     */
    private static Object proceedUnchecked(ProceedingJoinPoint joinPoint) {
        try {
            return joinPoint.proceed();
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable throwable) {
            throw new BusinessException(throwable);
        }
    }

    /**
     * Reads every post counter in one pipeline and returns the messages of those already at their limit.
     * A missing key means nothing has been counted in the current window.
     */
    private Set<String> findExhausted(List<RateLimitConfig> postLimitConfigs) {
        List<Object> counts = redisTemplate.executePipelined(new SessionCallback<Number>() {
            @Override
            @SuppressWarnings({"rawtypes", "unchecked"})
            public Number execute(RedisOperations operations) throws DataAccessException {
                ValueOperations valueOperations = operations.opsForValue();
                for (RateLimitConfig config : postLimitConfigs) {
                    valueOperations.get(config.getKey());
                }
                return null;
            }
        });
        Set<String> errMsgSet = new HashSet<>(postLimitConfigs.size());
        for (int i = 0; i < postLimitConfigs.size(); i++) {
            Number count = (Number) counts.get(i);
            RateLimitConfig config = postLimitConfigs.get(i);
            if (count != null && count.longValue() >= config.getLimit()) {
                errMsgSet.add(config.getMessage());
            }
        }
        return errMsgSet;
    }

    /**
     * Runs the counter script for every config in one pipeline.
     *
     * <p>Limiters are evaluated independently: when one rejects, the others in the same batch may
     * still have consumed a token.
     *
     * @param rateLimitConfigs configs whose keys are already resolved; may be empty
     * @return messages of the rejected limiters (empty if all were accepted)
     * @throws IllegalStateException if a window is not between 1 ms and {@link Integer#MAX_VALUE} ms
     */
    private Set<String> tryAcquire(RateLimitConfig... rateLimitConfigs) {
        Set<String> errorMsg = new HashSet<>(rateLimitConfigs.length);
        if (rateLimitConfigs.length == 0) {
            // Nothing to check: skip the Redis round trip.
            return errorMsg;
        }
        // Validate windows before opening the pipeline so a misconfiguration fails cleanly.
        int[] periods = new int[rateLimitConfigs.length];
        for (int i = 0; i < rateLimitConfigs.length; i++) {
            periods[i] = windowMillis(rateLimitConfigs[i]);
        }
        List<Object> results = redisTemplate.executePipelined(new SessionCallback<Number>() {
            @Override
            @SuppressWarnings({"rawtypes", "unchecked"})
            public Number execute(RedisOperations operations) throws DataAccessException {
                for (int i = 0; i < rateLimitConfigs.length; i++) {
                    RateLimitConfig config = rateLimitConfigs[i];
                    // The window is passed as an int: passing a long here was observed to fail with
                    // "ERR value is not an integer or out of range".
                    operations.execute(RATE_LUA_SCRIPT, Collections.singletonList(config.getKey()),
                            config.getLimit(), periods[i]);
                }
                return null;
            }
        });
        for (int i = 0; i < rateLimitConfigs.length; i++) {
            Number accepted = (Number) results.get(i);
            // The script returns 0 when the limit has been reached.
            if (accepted.longValue() == 0L) {
                errorMsg.add(rateLimitConfigs[i].getMessage());
            }
        }
        return errorMsg;
    }

    /**
     * Converts the config's window to milliseconds, rejecting values the int cast would corrupt.
     * A window of 0 or less, or one that overflows, would make PEXPIRE drop the key and silently
     * disable limiting.
     */
    private static int windowMillis(RateLimitConfig config) {
        long millis = config.getIntervalUnit().toMillis(config.getRateInterval());
        if (millis <= 0 || millis > Integer.MAX_VALUE) {
            throw new IllegalStateException("Invalid rate limit window: " + config.getRateInterval() + " "
                    + config.getIntervalUnit() + " (must be between 1 and " + Integer.MAX_VALUE + " ms)");
        }
        return (int) millis;
    }

    /**
     * Replaces each config's raw key with the concrete Redis key for this invocation. The result is
     * the prefix plus the resolved key, followed by ":" and the client IP or user id for the IP/USER
     * dimensions.
     */
    private void setKey(String prefix, ProceedingJoinPoint joinPoint, List<RateLimitConfig> limitConfigs) {
        for (RateLimitConfig limitConfig : limitConfigs) {
            String key = limitConfig.getKey();
            RateLimiterType type = limitConfig.getType();
            String redisKey;
            if (type == RateLimiterType.CLIENT_IP) {
                redisKey = prefix + dimensionBaseKey(joinPoint, key) + ":" + currentClientIp();
            } else if (type == RateLimiterType.USER) {
                redisKey = prefix + dimensionBaseKey(joinPoint, key) + ":" + currentUserId();
            } else {
                // CUSTOM: the key alone identifies the counter, so it must be provided.
                Assert.hasText(key, "限流Key不能为空!");
                redisKey = prefix + parseElKey(joinPoint, key);
            }
            limitConfig.setKey(redisKey);
        }
    }

    /**
     * Returns the base key for the IP/USER dimensions. A blank key falls back to the declaring class
     * and method. Without that fallback, every endpoint would share a single per-IP / per-user counter,
     * regardless of its own limit and window.
     */
    private String dimensionBaseKey(ProceedingJoinPoint joinPoint, String key) {
        return StringUtils.isBlank(key) ? getMethodKey(joinPoint, "") : parseElKey(joinPoint, key);
    }

    /** Returns the client IP of the current HTTP request; fails outside a servlet request. */
    private static String currentClientIp() {
        HttpServletRequest request = Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                .map(ServletRequestAttributes.class::cast)
                .map(ServletRequestAttributes::getRequest)
                .orElseThrow(() -> new IllegalStateException("只能在Web环境中获取Request对象!"));
        return ServletUtil.getClientIP(request);
    }

    /** Returns the current security principal as a string; fails when there is none. */
    private static String currentUserId() {
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication == null || authentication.getPrincipal() == null) {
            throw new IllegalStateException("USER rate limiting requires an authenticated principal");
        }
        return authentication.getPrincipal().toString();
    }

    /**
     * Evaluates a post limiter's condition against the method's return value, bound as {@code #rtv}.
     * Returns true (count the call) when the condition is blank, contains no {@code #}, or evaluates
     * to {@code null}.
     */
    private boolean parsePostSpEl(Object val, RateLimitConfig limitConfig) {
        String condition = limitConfig.getCondition();
        if (StringUtils.isBlank(condition) || !condition.contains(EL_PREFIX)) {
            return true;
        }
        StandardEvaluationContext context = new StandardEvaluationContext();
        context.setVariable("rtv", val);
        Expression expression = expressionParser.parseExpression(condition);
        return Optional.ofNullable(expression.getValue(context, Boolean.class)).orElse(true);
    }
}
上面使用了Redis pipeline和Redis分布式锁,逻辑不难,就不做细讲了,有疑问欢迎提问!
父类代码
/**
 * Shared helpers for aspects that build keys from Spring EL expressions over the intercepted
 * method's arguments.
 */
public abstract class AbstractAspect {
    /** Parser used for key and condition expressions. */
    protected static ExpressionParser expressionParser = new SpelExpressionParser();
    /** A key or condition containing this marker is treated as a Spring EL expression. */
    protected static final String EL_PREFIX = "#";

    /**
     * Builds a method-qualified key: {@code method:<SimpleClassName>#<methodName>#<resolved elKey>}.
     *
     * @param joinPoint the intercepted method call
     * @param elKey     literal key or SpEL expression, resolved via {@link #parseElKey}
     * @return the qualified key
     */
    protected String getMethodKey(ProceedingJoinPoint joinPoint, String elKey) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        String limitKey = parseElKey(joinPoint, elKey);
        String className = signature.getDeclaringType().getSimpleName();
        String methodName = signature.getName();
        return "method:" + className + "#" + methodName + "#" + limitKey;
    }

    /**
     * Resolves a key. A key containing {@link #EL_PREFIX} is evaluated as SpEL. Each argument is bound
     * by parameter name (when names are available) and also as {@code #p0}/{@code #a0}, {@code #p1}/{@code #a1}, ...
     * Parameter names take precedence over those aliases.
     *
     * @param joinPoint the intercepted method call
     * @param elKey     literal key or SpEL expression; returned unchanged when it is null or literal
     * @return the resolved key, possibly {@code null} if the expression evaluates to null
     */
    protected String parseElKey(ProceedingJoinPoint joinPoint, String elKey) {
        if (elKey == null || !elKey.contains(EL_PREFIX)) {
            return elKey;
        }
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        StandardEvaluationContext context = new StandardEvaluationContext();
        bindArguments(context, signature.getParameterNames(), joinPoint.getArgs());
        Expression expression = expressionParser.parseExpression(elKey);
        return expression.getValue(context, String.class);
    }

    /**
     * Returns the intercepted method's arguments keyed by parameter name, in declaration order.
     * The map is empty when the signature is not a method signature or names are unavailable.
     *
     * @param joinPoint the intercepted call
     * @return parameter name to argument value
     */
    protected Map<String, Object> getMethodParameters(JoinPoint joinPoint) {
        Map<String, Object> parameters = new LinkedHashMap<>(18);
        // The signature, not the join point itself, is the MethodSignature.
        if (joinPoint.getSignature() instanceof MethodSignature) {
            String[] argNames = ((MethodSignature) joinPoint.getSignature()).getParameterNames();
            Object[] argValues = joinPoint.getArgs();
            if (argNames != null && argValues != null) {
                int count = Math.min(argNames.length, argValues.length);
                for (int i = 0; i < count; i++) {
                    parameters.put(argNames[i], argValues[i]);
                }
            }
        }
        return parameters;
    }

    /**
     * Binds arguments into the evaluation context. Positional aliases are bound first, so that real
     * parameter names override them. Tolerates missing names (e.g. code compiled without parameter
     * metadata) and length mismatches.
     */
    private static void bindArguments(StandardEvaluationContext context, String[] names, Object[] args) {
        if (args == null) {
            return;
        }
        for (int i = 0; i < args.length; i++) {
            context.setVariable("p" + i, args[i]);
            context.setVariable("a" + i, args[i]);
        }
        if (names != null) {
            int count = Math.min(names.length, args.length);
            for (int i = 0; i < count; i++) {
                context.setVariable(names[i], args[i]);
            }
        }
    }
}
使用
@KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
基于IP的前置限流
@KeyRateLimiter(type = RateLimiterType.USER)
基于用户维度的前置限流
@PostKeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username", interval = 60, condition = "#rtv.code == 200")
自定义Key的后置限流:只有返回值的code == 200时才计为一次有效访问并参与限流(condition只在@PostKeyRateLimiter上可用,#rtv表示方法的返回值)。
@Slf4j
@RestController
@Api(tags = "系统:系统授权接口")
@RequiredArgsConstructor
public class AuthController {
@AnonymousAccess
@ApiOperation("获取验证码")
@GetMapping(value = "/code")
@KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
@KeyRateLimiter(type = RateLimiterType.USER)
@KeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username" interval="60" condtion="#rtv.code ==200")
public XCloudResponse<Object> getCode(@RequestParam String username) {
// 省略代码细节
}