@Slf4j @Component @Aspect @RequiredArgsConstructor (onConstructor_ = @Autowired ) public class RedisRateLimitAspect { private final static String REDIS_RATE_LIMIT_KEY_PREFIX = "limit:" ; private final StringRedisTemplate stringRedisTemplate; private final RedisScript<Long> limitRedisScript; @Pointcut ( "@annotation(com.rock.demo.annotation.RedisRateLimiter)" ) public void rateLimit() { } @Before ( "rateLimit()" ) public void pointCut(JoinPoint joinPoint) throws IllegalAccessException { MethodSignature signature = (MethodSignature) joinPoint.getSignature(); Method method = signature.getMethod(); // 通过 AnnotationUtils.findAnnotation 获取 RateLimiter 注解 RedisRateLimiter redisRateLimit = AnnotationUtils.findAnnotation(method, RedisRateLimiter. class ); RedisRateLimiter.LimitType limitType = redisRateLimit.limitType(); if (redisRateLimit != null ) { //获取时间限制 long timeLimitLength = redisRateLimit.timeLimitLength(); //获取时间限制单位 TimeUnit timeLimitLengthUnit = redisRateLimit.timeLimitLengthUnit(); //时间单位最大访问数目 long max = redisRateLimit.max(); String limitKey = "" ; //ip限流则用ip做limitKeyValue,其他的从参数中获取 if (RedisRateLimiter.LimitType.IP.name().equals(limitType.name())) { HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest(); limitKey = IpUtil.getIpAddr(request); } else { Object[] args = joinPoint.getArgs(); // 1.入参是(String id) limitKey = getIdForSingle(joinPoint, args); if (StringUtils.isEmpty(limitKey)) { // 2.入参是DTO,LimitKey注解放到DTO.id上。 limitKey = getIdForParamsDTO(args); } } //redis存储的key;"limit:"${className}"."${methodName}:${limitKey} String storeKey = REDIS_RATE_LIMIT_KEY_PREFIX + method.getDeclaringClass().getSimpleName() + "." + method.getName() + ":" + limitKey; long now = System.currentTimeMillis(); //将2分钟转化为毫秒时间戳,以获得2分钟前时间 long limitTimeLengthMills = timeLimitLengthUnit.toMillis(timeLimitLength); //应该移除的分值区间 long removeScore = now - limitTimeLengthMills; Long r = stringRedisTemplate.execute( limitRedisScript, Lists.newArrayList(storeKey), "" + now, "" + limitTimeLengthMills, //设置key的保存时间,该key在2分钟的允许时间内做zadd操作 "" + removeScore, //移除当前时间2分钟前过期的score "" + max); //当前接口访问上线 if (r != null ) { if (r == 0 ) { log.error( "【{}】在 " + timeLimitLength + formatTimeUnit(timeLimitLengthUnit) + " 内已达到访问上限,当前接口上限 {}" , storeKey, max); throw new RuntimeException( "手速太快了,慢点儿吧~" ); } else { log.info( "【{}】在 " + timeLimitLength + formatTimeUnit(timeLimitLengthUnit) + " 内访问 {} 次" , storeKey, r); } } } } private String formatTimeUnit(TimeUnit timeUnit) { if (timeUnit == TimeUnit.MINUTES) { return "分钟" ; } else if (timeUnit == TimeUnit.SECONDS) { return "秒" ; } else if (timeUnit == TimeUnit.HOURS) { return "小时" ; } return "illegal timeUnit args" ; } private String getIdForSingle(JoinPoint joinPoint, Object[] args) { if (Objects.nonNull(args) && args.length > 0 ) { MethodSignature signature = (MethodSignature) joinPoint.getSignature(); Annotation[][] parameterAnnotations = signature.getMethod().getParameterAnnotations(); if (Objects.isNull(parameterAnnotations)) { return null ; } // 循环判断是否有limitkey for ( int i = 0 ; i < parameterAnnotations.length; i++) { for (Annotation annotation : parameterAnnotations[i]) { if (annotation instanceof LimitKey) { return (String)args[i]; } } } } return null ; } private String getIdForParamsDTO( Object[] args) throws IllegalAccessException { for (Object arg : args) { String id = getIdForDto(arg); if (StringUtils.isNotEmpty(id)){ return id; } } return null ; } private String getIdForDto(Object arg) throws IllegalAccessException { Field[] fields = arg.getClass().getDeclaredFields(); if (Objects.nonNull(fields) && fields.length > 0 ){ for (Field field : fields) { field.setAccessible( true ); if (field.getType().equals(String. class ) && field.isAnnotationPresent(LimitKey. class )) { return (String)field.get(arg); } } } return null ; } } |