import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.concurrent.TimeUnit;
/**
 * Requests token-bucket throttling for the annotated method (or type).
 *
 * <p>The annotation is retained at runtime so an aspect can read it reflectively. It is
 * {@link Inherited}, so subclasses of an annotated type also carry it.
 * NOTE(review): the aspect in this file only reads the annotation from the intercepted method,
 * so a type-level {@code @RateLimit} appears to have no effect there — confirm before relying on it.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface RateLimit {

    /** Rate at which permits are issued, in permits per second. Defaults to 10. */
    double permitsPerSecond() default 10.0;

    /** How long a caller may wait for a permit before being rejected, in {@link #timeUnit()}. */
    long timeout() default 0L;

    /** Unit in which {@link #timeout()} is expressed. Defaults to milliseconds. */
    TimeUnit timeUnit() default TimeUnit.MILLISECONDS;
}
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.RateLimiter;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.annotations.Results;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
/**
 * Aspect that throttles methods annotated with {@link RateLimit} using Guava {@link RateLimiter}
 * token buckets.
 *
 * <p>A limiter is created on first use for each key and then reused for the lifetime of this
 * bean. The key is the current request URI when the call runs on a thread bound to a servlet
 * request, and the advised method's signature otherwise. Only method-level annotations are read:
 * the {@code @annotation(...)} pointcut matches annotated methods.
 */
@Scope
@Aspect
@Component
@Slf4j
public class RateLimitAspect {

    /** Returned to the caller, and logged, when no permit could be acquired in time. */
    private static final String TOO_MANY_REQUESTS = "请求太过频繁,请稍后重试";

    // Concurrent map, so computeIfAbsent creates at most one limiter per key even under contention.
    // NOTE(review): keys are raw request URIs, so URIs that embed path variables each get their own
    // bucket and the map grows without bound; consider keying by method instead.
    private final Map<String, RateLimiter> limitMap = Maps.newConcurrentMap();

    @Pointcut("@annotation(com.xxx.RateLimit)")
    private void pointcut() {
    }

    /**
     * Invokes the advised method if a permit is available within the configured timeout, and
     * otherwise rejects the call.
     *
     * <p>Any exception thrown by the advised method propagates to the caller unchanged.
     *
     * @param joinPoint the intercepted invocation
     * @return the advised method's result, or {@code TOO_MANY_REQUESTS} when the call is throttled
     */
    @Around(value = "pointcut()")
    public Object around(ProceedingJoinPoint joinPoint) {
        RateLimit rateLimit = ((MethodSignature) joinPoint.getSignature()).getMethod()
                .getAnnotation(RateLimit.class);
        if (rateLimit == null) {
            // The resolved method may not carry the annotation (e.g. the signature refers to an
            // interface method). Run it unthrottled instead of silently skipping the invocation.
            return proceed(joinPoint);
        }

        double permitsPerSecond = rateLimit.permitsPerSecond();
        long timeout = rateLimit.timeout();
        TimeUnit timeUnit = rateLimit.timeUnit();

        // Atomic get-or-create: replaces the racy containsKey/put/get sequence.
        RateLimiter rateLimiter = limitMap.computeIfAbsent(resolveKey(joinPoint), key -> {
            log.warn("请求======>{}创建令牌桶,容量为:{}", key, permitsPerSecond);
            return RateLimiter.create(permitsPerSecond);
        });

        if (rateLimiter.tryAcquire(timeout, timeUnit)) {
            return proceed(joinPoint);
        }
        log.warn(TOO_MANY_REQUESTS);
        // NOTE(review): if the advised method declares a return type other than String/Object,
        // the caller will fail to cast this value; throwing a dedicated exception may be preferable.
        return TOO_MANY_REQUESTS;
    }

    /** Returns the limiter key: the request URI inside a servlet request, else the method signature. */
    private static String resolveKey(ProceedingJoinPoint joinPoint) {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes) {
            HttpServletRequest request = ((ServletRequestAttributes) attributes).getRequest();
            return request.getRequestURI();
        }
        // No servlet request is bound to this thread, so there is no URI to key on. Throttle per
        // method rather than failing with a NullPointerException.
        return joinPoint.getSignature().toLongString();
    }

    /** Proceeds with the invocation, rethrowing whatever it throws without wrapping it. */
    private static Object proceed(ProceedingJoinPoint joinPoint) {
        try {
            return joinPoint.proceed();
        } catch (Throwable t) {
            throw RateLimitAspect.<RuntimeException>rethrowUnchecked(t);
        }
    }

    /**
     * Rethrows {@code t} as-is, including checked exceptions, without declaring them.
     * This keeps {@link #around} transparent to the advised method's own exceptions.
     */
    @SuppressWarnings("unchecked")
    private static <T extends Throwable> RuntimeException rethrowUnchecked(Throwable t) throws T {
        throw (T) t;
    }
}