import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;
/**
 * Marks a handler method as rate limited.
 *
 * <p>Retained at runtime so it can be read reflectively from the handler method when a request
 * arrives.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Limiter {

    /** Permits per second allowed for the annotated method. */
    double QPS() default 200.0;

    /** Maximum time a request may wait for a permit, expressed in {@link #timeunit()} units. */
    long timeout() default 500L;

    /** Unit in which {@link #timeout()} is expressed. */
    TimeUnit timeunit() default TimeUnit.MILLISECONDS;

    /** Message intended for clients whose request is rejected by the limiter. */
    String msg() default "请稍后再试!";
}
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.google.common.util.concurrent.RateLimiter;
import lombok.extern.slf4j.Slf4j;
@Component
@Slf4j
public class RequestLimitingInterceptor implements HandlerInterceptor {
private final Map<String, RateLimiter> rateLimiterMap = new ConcurrentHashMap<>();
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
JSONObject jsonObject = new JSONObject();
jsonObject.put("code", "msg");
try {
if (handler instanceof HandlerMethod) {
HandlerMethod handlerMethod = (HandlerMethod) handler;
Limiter rateLimit = handlerMethod.getMethodAnnotation(Limiter.class);
if (rateLimit != null) {
String url = request.getRequestURI();
RateLimiter rateLimiter;
if (!rateLimiterMap.containsKey(url)) {
rateLimiter = RateLimiter.create(rateLimit.QPS());
rateLimiterMap.put(url, rateLimiter);
}
rateLimiter = rateLimiterMap.get(url);
boolean acquire = rateLimiter.tryAcquire(rateLimit.timeout(), rateLimit.timeunit());
if (acquire) {
return true;
} else {
log.warn("请求被限流,url:{}", request.getServletPath());
response(response, toJsonObject(jsonObject));
return false;
}
}
}
return true;
} catch (Exception e) {
e.printStackTrace();
response(response, toJsonObject(jsonObject));
return false;
}
}
private void response(HttpServletResponse response, JSONObject jo) {
response.setContentType("application/json; charset=utf-8");
response.setCharacterEncoding("UTF-8");
try (PrintWriter out = response.getWriter()) {
out.append(jo.toJSONString());
} catch (Exception e) {
e.printStackTrace();
}
}
private JSONObject toJsonObject(Object o) {
return JSONObject.parseObject(JSON.toJSONString(o));
}
}
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
 * Spring MVC configuration that installs the request-limiting interceptor for every request path.
 */
@Configuration
public class WebMvcConfiguration implements WebMvcConfigurer {

    /** Interceptor registered for all paths; injected by Spring. */
    @Autowired
    protected RequestLimitingInterceptor requestLimitingInterceptor;

    /**
     * Registers the request-limiting interceptor for all paths.
     *
     * @param registry the interceptor registry supplied by Spring MVC
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        final String[] everyPath = {"/**"};
        registry.addInterceptor(requestLimitingInterceptor)
                .addPathPatterns(everyPath);
    }
}
@Limiter(QPS = 0.2, timeout = 100000, timeunit = TimeUnit.MILLISECONDS,msg = "玩命加载中,请稍后再试")