使用 Spring Boot 自定义注解和AOP实现基于IP的接口限流和黑白名单
在我们日常开发的项目中为了保证系统的稳定性,很多时候我们需要对系统做限流处理,它可以有效防止恶意请求对系统造成过载。常见的限流方案主要有:
网关限流: NGINX、Zuul 等 API 网关
服务器端限流: 服务端接口限流
令牌桶算法: 通过定期生成令牌放入桶中,请求需要消耗令牌才能通过
熔断机制: Hystrix、Resilience4j 等
本文将详细介绍 Spring Boot 通过自定义注解和 AOP(面向切面编程),实现基于 IP 的限流和黑白名单功能,包括如何使用 Redis 存储限流和黑名单信息。
项目搭建
添加必要的依赖。在 pom.xml 文件中添加以下内容:
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
</dependencies>
配置 application.yml 加入 redis 配置
spring:
#redis
redis:
# 地址
host: 127.0.0.1
# 端口,默认为6379
port: 6379
# 数据库索引
database: 0
# 密码
password: password
# 连接超时时间
timeout: 10s
lettuce:
pool:
# 连接池中的最小空闲连接
min-idle: 0
# 连接池中的最大空闲连接
max-idle: 8
# 连接池的最大数据库连接数
max-active: 8
# #连接池最大阻塞等待时间(使用负值表示没有限制)
max-wait: -1ms
自定义限流注解
创建一个自定义注解 RateLimit :
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimit {
//限制次数
int limit() default 5;
//限制时间 秒
int timeout() default 60;
}
编写限流切面
使用 AOP 实现限流逻辑,并增加 IP 黑白名单判断 , 使用 Redis 来存储和检查请求次数及黑名单信息。
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import javax.servlet.http.HttpServletRequest;
import java.util.concurrent.TimeUnit;
@Aspect
@Component
public class RateLimitAspect {
@Autowired
private StringRedisTemplate redisTemplate;
@Autowired
private HttpServletRequest request;
//定义黑名单key前缀
private static final String BLACKLIST_KEY_PREFIX = "blacklist:";
//定义白名单key前缀
private static final String WHITELIST_KEY_PREFIX = "whitelist:";
@Around("@annotation(rateLimit)")
public Object rateLimit(ProceedingJoinPoint joinPoint, RateLimit rateLimit) throws Throwable {
//获取IP
// String ip =request.getRemoteAddr();
/**
* 没有经过代理使用: request.getRemoteAddr();
*经过nginx代理使用: request.getHeader("X-Real-IP");
**/
String ip =IpUtil.getIpAddress(request);
//黑名单则直接异常
if (isBlacklisted(ip)) {
throw new RuntimeException("超出访问限制已加入黑名单,1小时后再访问");
}
//如果是白名单下的不做限制
if (isWhitelisted(ip)) {
return joinPoint.proceed();
}
String key = generateKey(joinPoint, ip);
int limit = rateLimit.limit();
int timeout = rateLimit.timeout();
String countStr = redisTemplate.opsForValue().get(key);
int count = countStr == null ? 0 : Integer.parseInt(countStr);
if (count < limit) {
redisTemplate.opsForValue().set(key, String.valueOf(count + 1), timeout, TimeUnit.SECONDS);
return joinPoint.proceed();
} else {
addToBlacklist(ip);
throw new RuntimeException("超出请求限制IP已被列入黑名单");
}
}
// 判断是否在黑名单列表内
private boolean isBlacklisted(String ip) {
return redisTemplate.hasKey(BLACKLIST_KEY_PREFIX + ip);
}
// 是否在白名单内
private boolean isWhitelisted(String ip) {
return redisTemplate.hasKey(WHITELIST_KEY_PREFIX + ip);
}
// 添加ip到白名单内
private void addToBlacklist(String ip) {
redisTemplate.opsForValue().set(BLACKLIST_KEY_PREFIX + ip, "true", 1, TimeUnit.HOURS);
}
// redis key 拼接
private String generateKey(ProceedingJoinPoint joinPoint, String ip) {
String methodName = joinPoint.getSignature().getName();
String className = joinPoint.getTarget().getClass().getName();
return className + ":" + methodName + ":" + ip;
}
}
/**
* IP工具类
*/
public class IpUtil {
/**
* 获取ip
* @param request 请求
* @return {@link String }
*/
public static String getIpAddress(HttpServletRequest request) {
String ipAddress = null;
try {
ipAddress = request.getHeader("x-forwarded-for");
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("WL-Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getRemoteAddr();
if (ipAddress.equals("127.0.0.1")) {
// 根据网卡取本机配置的IP
InetAddress inet = null;
try {
inet = InetAddress.getLocalHost();
} catch (UnknownHostException e) {
e.printStackTrace();
}
ipAddress = inet.getHostAddress();
}
}
// 对于通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照','分割
if (ipAddress != null && ipAddress.length() > 15) { // "***.***.***.***".length()
// = 15
if (ipAddress.indexOf(",") > 0) {
ipAddress = ipAddress.substring(0, ipAddress.indexOf(","));
}
}
} catch (Exception e) {
ipAddress="";
}
// ipAddress = this.getRequest().getRemoteAddr();
return ipAddress;
}
/**
* 获取网关ip
* @param request 请求
* @return {@link String }
*/
public static String getGatwayIpAddress(ServerHttpRequest request) {
HttpHeaders headers = request.getHeaders();
String ip = headers.getFirst("x-forwarded-for");
if (ip != null && ip.length() != 0 && !"unknown".equalsIgnoreCase(ip)) {
// 多次反向代理后会有多个ip值,第一个ip才是真实ip
if (ip.indexOf(",") != -1) {
ip = ip.split(",")[0];
}
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = headers.getFirst("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = headers.getFirst("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = headers.getFirst("HTTP_CLIENT_IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = headers.getFirst("HTTP_X_FORWARDED_FOR");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = headers.getFirst("X-Real-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddress().getAddress().getHostAddress();
}
return ip;
}
}
Controller中使用限流注解
创建一个简单的限流测试Controller,并在需要限流的方法上使用 @RateLimit 注解:,需要编写异常处理,返回RateLimitAspect异常信息,并以字符串形式返回
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api")
public class TestController {
@RateLimit(limit = 5, timeout = 60)
@GetMapping("/limit")
public String testRateLimit() {
return "Request successful!";
}
}