使用springboot集成redis实现一个简单的限流功能。
实现简单的限流可以通过自定义注解来实现,限流可以分为不同的策略,如针对接口的全局性限流、针对ip的限流,限制1分钟内访问的次数。
实例
限流方式的枚举类
public enum LimitType {
/**
* 默认策略全局限流
*/
DEFAULT,
/**
* 根据请求IP进行限流
*/
IP
}
限流注解
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
/**
* 限流key前缀
*/
String keyPrefix() default "rate_limiter:";
/**
* 限流时间,单位秒
*/
int time() default 60;
/**
* 限流次数
*/
int count() default 100;
/**
* 限流类型
*/
LimitType limitType() default LimitType.DEFAULT;
}
RedisTemplate类,在 Spring Boot 中,默认的 RedisTemplate 有一个小坑,就是序列化用的是 JdkSerializationRedisSerializer,直接用这个序列化工具将来存到 Redis 上的 key 和 value 都会莫名其妙多一些前缀,这就导致你用命令读取的时候可能会出错,此时只能继续使用 RedisTemplate 将之读取出来。
用redis实现限流会用到lua脚本,使用lua脚本的时候就会出现上面的问题,所有需要修改RedisTempplate的序列化方案。
@Configuration
public class RedisConfig {
@Bean
public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(redisConnectionFactory);
// 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化)
Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
redisTemplate.setKeySerializer(jackson2JsonRedisSerializer);
redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer);
redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
return redisTemplate;
}
@Bean
public DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
}
其中key、value都使用了jackson序列化方式来解决。redis中的一些原子操作可以借助lua脚本来实现,可以在resources目录下新建lua目录来存放lua脚本文件,内容如下
local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key)
if current and tonumber(current) > count then
return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current)
其中redis.call 就是执行具体的 redis 指令。具体执行流程如下:
- 首先获取传进来的key、count、time;
- 通过get获取到这个key对应的值,这个值就是当前时间窗口内可以访问的次数;
- 如果是第一次访问,则拿到的结果为空,否则拿到的结果是一个数字。所以判断若是拿到的结果是一个数字并且这个数字还大于cout则说明已经超过限流限制了,则直接返回查询的结果即可;
- 若是拿到的结果为空,则说明是第一次访问,次数就给当前key自增1,然后设置过期时间,最后把自增1后的值返回即可;
限流切面
@Aspect
@Component
public class RateLimiterAspect {
private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);
@Autowired
private RedisTemplate<Object, Object> redisTemplate;
@Autowired
private RedisScript<Long> redisScript;
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
String keyPrefix = rateLimiter.keyPrefix();
int time = rateLimiter.time();
int count = rateLimiter.count();
String redisKey = getRedisKey(rateLimiter, point);
List<Object> keys = Collections.singletonList(redisKey);
// try {
Long number = redisTemplate.execute(redisScript, keys, count, time);
if (number == null || number.intValue() > count) {
throw new ServiceException("访问过于频繁,请稍候再试");
}
log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), keyPrefix);
// } catch (Exception e) {
// throw new RuntimeException("服务器限流异常,请稍候再试");
// }
}
public String getRedisKey(RateLimiter rateLimiter, JoinPoint point) {
StringBuffer stringBuffer = new StringBuffer(rateLimiter.keyPrefix());
if (rateLimiter.limitType() == LimitType.IP) {
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = attributes.getRequest();
stringBuffer.append(IPUtil.getIpAddress(request)).append("-");
}
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
return stringBuffer.toString();
}
}
拦截了所有加了@Ratelimiter注解的方法,在前置通知中对注解进行处理。
全局异常处理类
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(ServiceException.class)
public Map<String,Object> globalException(ServiceException e) {
HashMap<String, Object> map = new HashMap<>();
map.put("status", 500);
map.put("message", e.getMessage());
return map;
}
@ExceptionHandler(value = Exception.class)
public Map<String,Object> jsonCommonErrorHandler(HttpServletRequest req, Exception e) {
HashMap<String, Object> map = new HashMap<>();
map.put("status", 500);
map.put("message", e.getMessage());
return map;
}
}
测试
@GetMapping("/hello")
@RateLimiter(time = 5, count = 3, limitType = LimitType.IP)
public String hello() {
return "hello";
}
@GetMapping("/hello2")
@RateLimiter(time = 5, count = 3, limitType = LimitType.DEFAULT)
public String hello2() {
return "hello2";
}