在实际开发中,有时候我们需要对某些接口进行限流,防止有人恶意攻击或者是因为某些接口自身的原因,比如发短信接口,IO处理的接口。
这里我们通过自定义一个注解,并利用Spring的AOP拦截器功能来实现限流的功能。限流需要用到redis。
代码:
Limit.java
这里我们有两种限流类型,一种是根据接口本身来进行限流,一种是根据ip来进行限流
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {
// 资源名称,用于描述接口功能
String name() default "";
// 资源 key
String key() default "";
// key prefix
String prefix() default "";
// 时间的,单位秒
int period() default 60;
// 限制访问次数
int count() default 60;
// 限制类型
LimitType limitType() default LimitType.IP;
/**
* 使用实例:
* 测试限流注解,下面配置说明该接口 60秒内最多只能访问 10次,保存到redis的键名为 limit_test,
* 即 prefix + "_" + key,也可以根据 IP 来限流,需指定limitType = LimitType.IP
*/
// @Limit(key = "test", period = 60, count = 10, name = "resource", prefix = "limit")
// @GetMapping("/test")
// public int testLimiter() {
// return ATOMIC_INTEGER.incrementAndGet();
// }
}
LimitAspect.java
@Aspect
@Component
public class LimitAspect {
private static final Logger logger = LoggerFactory.getLogger(LimitAspect.class);
@Autowired
private RedisTemplate<String, Object> limitRedisTemplate;
@Pointcut("@annotation(com.yfy.annotation.Limit)")
public void pointcut() {
// do nothing
}
@Around("pointcut()")
public Object around(ProceedingJoinPoint point) throws Throwable {
HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Limit limitAnnotation = method.getAnnotation(Limit.class);
LimitType limitType = limitAnnotation.limitType();
String name = limitAnnotation.name();
String key;
int limitPeriod = limitAnnotation.period();
int limitCount = limitAnnotation.count();
switch (limitType) {
case IP:
key = IPUtils.getIpAddr(request);
break;
case CUSTOMER:
key = limitAnnotation.key();
break;
default:
key = StringUtils.upperCase(method.getName());
}
ImmutableList<String> keys = ImmutableList.of(StringUtils.join(limitAnnotation.prefix() + "_", key + "_" + request.getRequestedSessionId()));
String luaScript = buildLuaScript();
RedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class);
Number count = limitRedisTemplate.execute(redisScript, keys, limitCount, limitPeriod);
logger.info("第{}次访问key为 {},描述为 [{}] 的接口", count, keys, name);
if (count != null && count.intValue() <= limitCount) {
return point.proceed();
} else {
throw new LimitAccessException("接口访问超出频率限制");
}
}
/**
* 限流脚本
* 调用的时候不超过阈值,则直接返回并执行计算器自加。
*
* @return lua脚本
*/
private String buildLuaScript() {
return "local c" +
"\nc = redis.call('get',KEYS[1])" +
"\nif c and tonumber(c) > tonumber(ARGV[1]) then" +
"\nreturn c;" +
"\nend" +
"\nc = redis.call('incr',KEYS[1])" +
"\nif tonumber(c) == 1 then" +
"\nredis.call('expire',KEYS[1],ARGV[2])" +
"\nend" +
"\nreturn c;";
}
}
IPUtils.java
public class IPUtils {
private static final String UNKNOWN = "unknown";
protected IPUtils(){
}
/**
* 获取IP地址
* 使用 Nginx等反向代理软件, 则不能通过request.getRemoteAddr()获取IP地址
* 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址,X-Forwarded-For中第一个非 unknown的有效IP字符串,则为真实IP地址
*/
public static String getIpAddr(HttpServletRequest request) {
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : ip;
}
}
Report.java
@Data
public class Report implements Serializable {
private Integer status;
private Object data;
private Object notice;
private Object msg;
}
ExceptionTemplet.java
public interface ExceptionTemplet {
Report report();
}
BaseServiceException.java
public abstract class BaseServiceException extends RuntimeException implements ExceptionTemplet {
private Report report;
public BaseServiceException(String message) {
super(message);
}
public BaseServiceException report(Report report){
this.report = report;
return this;
};
@Override
public Report report() {
return report;
}
}
LimitAccessException.java
public class LimitAccessException extends BaseServiceException {
private static final long serialVersionUID = -3608667856397125671L;
public LimitAccessException(String message) {
super(message);
}
@Override
public Report report() {
return ReportFactory.C_1404_CLIENT_REQUEST_DATAERROR.error(getMessage());
}
}
测试:
我们写一个controller接口
@GetMapping("/limit")
@ResponseBody
@Limit(key = "test", period = 60, count = 10, name = "resource", prefix = "limit")
public String testLimit() {
return "success";
}
该接口中的注解表明该接口在60秒内,同一个ip地址最多只能访问10次,如果访问超过10次,则抛出异常。
对于该异常,如果我们希望给用户友好的提示,可以利用Spring的全局异常处理类来对异常进行特殊处理。
Report类为前后端分离中与前端自定义的返回值类
ControllerExceptionAdvice.java
@RestControllerAdvice
public class ControllerExceptionAdvice {
private Logger logger = LoggerFactory.getLogger(ControllerExceptionAdvice.class);
@ExceptionHandler(value = LimitAccessException.class)
@ResponseBody
public Report handle(LimitAccessException e) {
return e.report();
}
}