使用spring aop + redis实现ip请求方法防刷
引入依赖
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- aop -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- redis -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
redis的配置文件
spring:
  redis:
    host: localhost
    port: 6379
    password:
    timeout: 5000ms
创建自定义注解
/**
 * Marks a method as rate limited.
 *
 * <p>The annotation only carries configuration: a call budget ({@link #frequency()}) per
 * time window ({@link #cycle()}), the message used when a call is rejected, and an expiry
 * time. It is retained at runtime so it can be read reflectively.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Limiter {

    /**
     * Maximum number of calls allowed within one {@link #cycle()} window, where the window
     * starts at the first call to the annotated method.
     *
     * @return the call budget per window; 20 by default
     */
    int frequency() default 20;

    /**
     * Length of the counting window, in milliseconds.
     *
     * @return the window length; one minute by default
     */
    long cycle() default 60_000L;

    /**
     * Error message reported when a call is rejected.
     *
     * @return the rejection message
     */
    String message() default "请求过于频繁";

    /**
     * Expiry time, in seconds. Described by the original author as how long a caller stays
     * blocked after exceeding {@link #frequency()} within one {@link #cycle()}.
     *
     * @return the expiry time; one minute by default
     */
    long expireTime() default 60L;
}
创建aspect
/**
 * Aspect that rate-limits methods annotated with {@code @Limiter}, per client IP and per
 * method signature.
 *
 * <p>State is kept in a Redis hash stored under {@code limiting:<ip>:<method signature>} with
 * two fields: the timestamp of the first call in the current window ({@code beginTime}) and
 * the number of calls admitted in that window ({@code exFrequency}).
 *
 * <p>NOTE(review): the counter is read and then written back in separate Redis calls, so the
 * update is not atomic; concurrent requests from the same IP can read the same count and all
 * be admitted. Confirm whether that is acceptable for the intended use.
 */
@Aspect
@Component
public class LimitingAspect {

    /** Redis key pattern; the placeholders are the client IP and the long method signature. */
    private static final String LIMITING_KEY = "limiting:%s:%s";
    /** Hash field holding the epoch-millisecond timestamp of the first call in the window. */
    private static final String LIMITING_BEGINTIME = "beginTime";
    /** Hash field holding the number of calls admitted in the current window. */
    private static final String LIMITING_EXFREQUENCY = "exFrequency";

    @Autowired
    private RedisTemplate redisTemplate;

    /** Matches every method annotated with {@code @Limiter}. */
    @Pointcut("@annotation(com.example.demo.annotation.Limiter)")
    public void limitAspect() {
    }

    /**
     * Admits or rejects the intercepted call according to the method's {@code @Limiter}
     * settings.
     *
     * @param joinPoint the intercepted method invocation
     * @return whatever the intercepted method returns
     * @throws Throwable an {@link Exception} carrying {@code Limiter.message()} when the call
     *     budget for the current window is used up, or anything thrown by the intercepted method
     */
    @Around("limitAspect()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        ServletRequestAttributes requestAttributes =
                (ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes());
        HttpServletRequest request = requestAttributes.getRequest();

        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Limiter limiter = signature.getMethod().getAnnotation(Limiter.class);
        long cycle = limiter.cycle();
        int frequency = limiter.frequency();
        long currentTime = System.currentTimeMillis();

        // Build the key once; every Redis call below targets the same hash.
        String key = String.format(LIMITING_KEY, WebUtil.getIpAddr(request), signature.toLongString());

        // Read through Number so that a serializer which returns a different numeric box type
        // does not cause a ClassCastException. A missing field counts as zero.
        Number storedBeginTime = asNumber(redisTemplate.opsForHash().get(key, LIMITING_BEGINTIME), LIMITING_BEGINTIME);
        Number storedCount = asNumber(redisTemplate.opsForHash().get(key, LIMITING_EXFREQUENCY), LIMITING_EXFREQUENCY);
        long beginTime = storedBeginTime == null ? 0L : storedBeginTime.longValue();
        int exFrequency = storedCount == null ? 0 : storedCount.intValue();

        if (currentTime - beginTime > cycle) {
            // The previous window is over (or never existed): start a new one at this call.
            redisTemplate.opsForHash().put(key, LIMITING_BEGINTIME, currentTime);
            redisTemplate.opsForHash().put(key, LIMITING_EXFREQUENCY, 1);
        } else if (exFrequency < frequency) {
            // Still inside the window with budget left: record this call.
            redisTemplate.opsForHash().put(key, LIMITING_EXFREQUENCY, exFrequency + 1);
        } else {
            // Budget exhausted for this window. The plain Exception type is kept as-is because
            // callers and exception handlers may depend on it.
            throw new Exception(limiter.message());
        }

        // Every admitted call refreshes the key's time-to-live to expireTime seconds.
        // NOTE(review): if expireTime is shorter than cycle, the hash can expire mid-window and
        // the count restarts early; confirm the intended relationship between the two settings.
        redisTemplate.expire(key, limiter.expireTime(), TimeUnit.SECONDS);
        return joinPoint.proceed();
    }

    /**
     * Interprets a value read from the Redis hash as a number.
     *
     * @param value the raw value returned by Redis, possibly {@code null}
     * @param field the hash field name, used only in the error message
     * @return the value as a {@link Number}, or {@code null} when the field is absent
     * @throws IllegalStateException if the stored value is not numeric
     */
    private static Number asNumber(Object value, String field) {
        if (value == null) {
            return null;
        }
        if (value instanceof Number) {
            return (Number) value;
        }
        throw new IllegalStateException(
                "Unexpected type for rate-limit field '" + field + "': " + value.getClass().getName());
    }
}
创建测试类
/**
 * Demo controller used to exercise the {@code @Limiter} annotation.
 */
@RestController
public class TestController {

    /**
     * Handles {@code GET /test} and returns a fixed greeting.
     * The {@code @Limiter} setting allows at most 3 calls per cycle.
     *
     * @return the literal string {@code "hello"}
     */
    @GetMapping("/test")
    @Limiter(frequency = 3)
    public String getString() {
        final String greeting = "hello";
        return greeting;
    }
}
结果
在浏览器中访问 /test:同一周期(默认 60 秒)内前三次都能正常获取返回值,第四次访问时抛出异常(异常信息默认为"请求过于频繁")。
需要限制访问次数的接口上加上@Limiter注解即可
附:切面中使用的 WebUtil.getIpAddr 方法:
/**
 * Resolves the client IP address for a request.
 *
 * <p>The headers {@code x-forwarded-for}, {@code Proxy-Client-IP} and
 * {@code WL-Proxy-Client-IP} are consulted in that order; the first one that yields a usable
 * value wins, otherwise {@link HttpServletRequest#getRemoteAddr()} is used. When a header
 * holds a comma-separated list, only its first entry is considered. If the result is
 * {@code 127.0.0.1}, the local host's own address is returned instead.
 *
 * <p>NOTE(review): these headers are supplied by the client unless a trusted proxy overwrites
 * them, so a caller can choose the value returned here. Confirm the deployment strips or
 * sets them before relying on this for rate limiting.
 *
 * @param request the current HTTP request
 * @return the resolved IP address; may be {@code null} only if no header is usable and
 *     {@code getRemoteAddr()} itself returns {@code null}
 */
public static String getIpAddr(HttpServletRequest request) {
    String ip = firstEntry(request.getHeader("x-forwarded-for"));
    if (isUnresolved(ip)) {
        ip = firstEntry(request.getHeader("Proxy-Client-IP"));
    }
    if (isUnresolved(ip)) {
        ip = firstEntry(request.getHeader("WL-Proxy-Client-IP"));
    }
    if (isUnresolved(ip)) {
        ip = request.getRemoteAddr();
    }

    String localhost = "127.0.0.1";
    if (localhost.equals(ip)) {
        // Replace the loopback address with the machine's own address.
        try {
            ip = InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException ignored) {
            // Best effort: if the local host name cannot be resolved, keep the loopback address.
        }
    }
    return ip;
}

/**
 * Returns the first entry of a comma-separated header value, trimmed.
 * Uses indexOf/substring rather than {@code split(",")[0]}, because {@code split} drops
 * trailing empty strings and yields an empty array for a value such as {@code ","}.
 *
 * @param headerValue the raw header value, possibly {@code null}
 * @return the trimmed first entry, or {@code null} when the header is absent
 */
private static String firstEntry(String headerValue) {
    if (headerValue == null) {
        return null;
    }
    int commaIndex = headerValue.indexOf(',');
    String first = commaIndex >= 0 ? headerValue.substring(0, commaIndex) : headerValue;
    return first.trim();
}

/** Tells whether a candidate address is missing, empty, or the UNKNOWN placeholder. */
private static boolean isUnresolved(String candidate) {
    return candidate == null || candidate.isEmpty() || UNKNOWN.equalsIgnoreCase(candidate);
}