1.定义注解
/**
 * Marks a method as rate limited and carries the limiter settings for it.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface CusRateLimiter {

    /**
     * Rate at which tokens are produced.
     *
     * @return the token generation rate, in permits per second
     */
    double permitsPerSecond();

    /**
     * Number of milliseconds the client waits for a token.
     *
     * <p>Note: despite its name, this value is a wait time, not a permit count.
     *
     * @return the wait time in milliseconds
     */
    int permits();
}
2.实现方法拦截
/**
 * Aspect that rate limits controller methods annotated with {@link CusRateLimiter}.
 *
 * <p>One {@link RateLimiter} is created lazily per request URI and reused for later
 * requests to the same URI. Requests that cannot obtain a token within the configured
 * wait time receive a plain "server busy" response instead of reaching the controller.
 */
@Aspect
@Component
public class RateLimiterAOP {

    /** Rate limiters keyed by request URI. */
    private final Map<String, RateLimiter> rateLimiterMap = new ConcurrentHashMap<>();

    /** Matches public methods of the types in {@code com.example.nginx.controller}. */
    @Pointcut("execution(public * com.example.nginx.controller.*.*(..))")
    public void pointCut() {
    }

    /**
     * Applies rate limiting around a matched method.
     *
     * @param proceedingJoinPoint the intercepted invocation
     * @return the result of the intercepted method, or {@code null} when the request was
     *         rejected and the fallback response was written instead
     * @throws Throwable anything thrown by the intercepted method, or an
     *         {@link IOException} raised while writing the fallback response
     */
    @Around("pointCut()")
    public Object around(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
        // Only methods carrying the annotation are rate limited.
        MethodSignature methodSignature = (MethodSignature) proceedingJoinPoint.getSignature();
        Method method = methodSignature.getMethod();
        CusRateLimiter cusRateLimiter = method.getDeclaredAnnotation(CusRateLimiter.class);
        if (cusRateLimiter == null) {
            return proceedingJoinPoint.proceed();
        }

        // Without a bound request there is no URI to key on and no response to write a
        // rejection to, so let the call through instead of failing with a NullPointerException.
        ServletRequestAttributes attributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            return proceedingJoinPoint.proceed();
        }

        // One token bucket per request URI.
        // NOTE(review): URIs that embed path variables each get their own limiter, so this
        // map can grow without bound - confirm whether keying by URI is still intended.
        String requestURI = attributes.getRequest().getRequestURI();

        // computeIfAbsent makes lookup-or-create atomic; a separate containsKey/get/put
        // sequence lets concurrent first requests create and overwrite each other's limiter.
        RateLimiter rateLimiter = rateLimiterMap.computeIfAbsent(
                requestURI, uri -> RateLimiter.create(cusRateLimiter.permitsPerSecond()));

        // permits() is passed as the timeout (milliseconds) to wait for a single token.
        if (rateLimiter.tryAcquire(cusRateLimiter.permits(), TimeUnit.MILLISECONDS)) {
            return proceedingJoinPoint.proceed();
        }

        failback(attributes.getResponse());
        return null;
    }

    /**
     * Service degradation handler: writes the "server busy" message to the response.
     *
     * @param response the current HTTP response; nothing is written when it is {@code null}
     * @throws IOException if the response writer cannot be obtained
     */
    private void failback(HttpServletResponse response) throws IOException {
        if (response == null) {
            return;
        }
        // The charset must be set before getWriter() so the message is encoded as UTF-8.
        response.setHeader("Content-type", "text/html;charset=UTF-8");
        try (PrintWriter writer = response.getWriter()) {
            writer.println("服务器忙, 请稍后重试");
        }
    }
}