一、项目结构图
二、项目构成
2.1 pom文件
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.6.7</version>
        <!-- Empty relativePath: resolve the parent from the repository, not the file system. -->
        <relativePath/>
    </parent>
    <groupId>org.javaboy.org.javaboy.rate_limiter</groupId>
    <artifactId>org.javaboy.rate_limiter</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>org.javaboy.rate_limiter</name>
    <description>Demo project for Spring Boot</description>
    <properties>
        <java.version>1.8</java.version>
    </properties>
    <dependencies>
        <!-- RedisTemplate / RedisScript used to run the rate-limit Lua script. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- Provides the AspectJ annotations (@Aspect, @Before) used by RateLimiterAspectj. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>
2.2 Lua 脚本(src/main/resources/lua/limit.lua)
-- Fixed-window request counter.
-- KEYS[1]: counter key; ARGV[1]: window length in seconds (used with EXPIRE);
-- ARGV[2]: maximum number of requests allowed in one window.
-- Returns the counter value; the caller rejects the request when it exceeds ARGV[2].
local limitKey = KEYS[1]
local windowSeconds = tonumber(ARGV[1])
local maxCount = tonumber(ARGV[2])

local existing = redis.call('get', limitKey)
-- Already over the limit: report the stored value without incrementing further.
if existing and tonumber(existing) > maxCount then
    return tonumber(existing)
end

local updated = redis.call('incr', limitKey)
-- First request of a new window: start the window's TTL.
if tonumber(updated) == 1 then
    redis.call('expire', limitKey, windowSeconds)
end
return tonumber(updated)
2.3 application.properties文件
server.port=8888
# Redis connection. Real host/password values must not be committed to source or docs;
# supply them via environment variables. Defaults target a local, password-less Redis.
spring.redis.host=${REDIS_HOST:127.0.0.1}
spring.redis.port=${REDIS_PORT:6379}
spring.redis.password=${REDIS_PASSWORD:}
2.4 annotation包下-RateLimiter注解
package org. javaboy. rate_limiter. annotation ;
import org. javaboy. rate_limiter. enums. LimitType ;
import java. lang. annotation. ElementType ;
import java. lang. annotation. Retention ;
import java. lang. annotation. RetentionPolicy ;
import java. lang. annotation. Target ;
/**
 * Marks a method whose invocations are counted in Redis and rejected once a limit is
 * exceeded within a time window.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {

    /** Prefix of the Redis key under which the request counter is stored. */
    String key() default "rate_limit";

    /** Length of the counting window, in seconds. */
    int time() default 60;

    /** Maximum number of requests allowed within one window. */
    int count() default 100;

    /** Whether the counter is shared by all callers or kept separately per client IP. */
    LimitType limitType() default LimitType.DEFAULT;
}
2.5、aspectj包下-RateLimiterAspectj
package org. javaboy. rate_limiter. aspectj ;
import org. aopalliance. intercept. Joinpoint ;
import org. aspectj. lang. JoinPoint ;
import org. aspectj. lang. annotation. Aspect ;
import org. aspectj. lang. annotation. Before ;
import org. aspectj. lang. reflect. MethodSignature ;
import org. javaboy. rate_limiter. annotation. RateLimiter ;
import org. javaboy. rate_limiter. enums. LimitType ;
import org. javaboy. rate_limiter. exception. RateLimitException ;
import org. javaboy. rate_limiter. utils. IpUtils ;
import org. slf4j. Logger ;
import org. slf4j. LoggerFactory ;
import org. springframework. beans. factory. annotation. Autowired ;
import org. springframework. data. redis. core. RedisTemplate ;
import org. springframework. data. redis. core. script. RedisScript ;
import org. springframework. stereotype. Component ;
import org. springframework. web. context. request. RequestContextHolder ;
import org. springframework. web. context. request. ServletRequestAttributes ;
import java. lang. reflect. Method ;
import java. util. Collections ;
/**
 * Aspect that enforces {@link RateLimiter} on annotated methods.
 *
 * <p>Before each invocation it executes the injected {@link RedisScript} with a key derived
 * from the annotation and the intercepted method, and rejects the call with
 * {@link RateLimitException} once the returned counter exceeds {@link RateLimiter#count()}.
 */
@Aspect
@Component
public class RateLimiterAspectj {
    private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspectj.class);

    @Autowired
    RedisTemplate<Object, Object> redisTemplate;

    @Autowired
    RedisScript<Long> redisScript;

    /**
     * Checks the rate limit before the annotated method runs.
     *
     * @param jp          the intercepted join point, used to derive the counter key
     * @param rateLimiter the annotation present on the intercepted method
     * @throws RateLimitException if the script returns {@code null} or a count above the limit
     */
    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint jp, RateLimiter rateLimiter) throws RateLimitException {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String combineKey = getCombineKey(rateLimiter, jp);
        // Redis access exceptions propagate to the caller unchanged.
        Long number = redisTemplate.execute(redisScript, Collections.singletonList(combineKey),
                time, count);
        // Compare as long so a very large counter value cannot overflow int.
        if (number == null || number.longValue() > count) {
            logger.info("当前接口达到最大限流次数");
            throw new RateLimitException("访问太频繁,一会在访问");
        }
        logger.info("一个时间窗内请求次数:{},当前请求次数:{},缓存的 key 为 {}", count, number, combineKey);
    }

    /**
     * Builds the counter key: the annotation's key prefix, then (for {@link LimitType#IP})
     * the client IP followed by "-", then the declaring class name, "-", and the method name.
     */
    private String getCombineKey(RateLimiter rateLimiter, JoinPoint jp) {
        StringBuilder key = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            // getRequestAttributes() is null when invoked outside an HTTP request thread;
            // pass a null request through (IpUtils maps it to "unknown") instead of an NPE.
            ServletRequestAttributes attributes =
                    (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            key.append(IpUtils.getIpAddr(attributes == null ? null : attributes.getRequest()))
                    .append("-");
        }
        Method method = ((MethodSignature) jp.getSignature()).getMethod();
        key.append(method.getDeclaringClass().getName())
                .append("-")
                .append(method.getName());
        return key.toString();
    }
}
2.6、config包下-RedisConfig
package org. javaboy. rate_limiter. config ;
import org. springframework. context. annotation. Bean ;
import org. springframework. context. annotation. Configuration ;
import org. springframework. core. io. ClassPathResource ;
import org. springframework. data. redis. connection. RedisConnectionFactory ;
import org. springframework. data. redis. core. RedisTemplate ;
import org. springframework. data. redis. core. script. DefaultRedisScript ;
import org. springframework. data. redis. serializer. Jackson2JsonRedisSerializer ;
import org. springframework. scripting. support. ResourceScriptSource ;
/**
 * Redis beans for the rate limiter: a JSON-serializing {@link RedisTemplate} and the
 * counter script loaded from the classpath.
 */
@Configuration
public class RedisConfig {

    /**
     * Template whose keys, values, hash keys and hash values are all serialized with
     * Jackson as JSON.
     */
    @Bean
    RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        Jackson2JsonRedisSerializer<Object> json = new Jackson2JsonRedisSerializer<Object>(Object.class);
        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        redisTemplate.setKeySerializer(json);
        redisTemplate.setValueSerializer(json);
        redisTemplate.setHashKeySerializer(json);
        redisTemplate.setHashValueSerializer(json);
        return redisTemplate;
    }

    /** Script read from {@code lua/limit.lua} on the classpath, returning a {@code Long}. */
    @Bean
    DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> limitScript = new DefaultRedisScript<>();
        limitScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        limitScript.setResultType(Long.class);
        return limitScript;
    }
}
2.7、enums包下的-LimitType
package org. javaboy. rate_limiter. enums ;
/** Scope in which a {@code RateLimiter} request counter is shared. */
public enum LimitType {
    /** One counter per annotated method, shared by all callers. */
    DEFAULT,
    /** One counter per annotated method and client IP address. */
    IP
}
2.8、exception包下定义的异常类
2.8.1 RateLimitException
package org. javaboy. rate_limiter. exception ;
/**
 * Checked exception signalling that a request exceeded its configured rate limit.
 */
public class RateLimitException extends Exception {

    /**
     * Creates the exception.
     *
     * @param message human-readable reason, returned by {@link #getMessage()}
     */
    public RateLimitException(String message) {
        super(message);
    }
}
2.8.2 GlobalException
package org. javaboy. rate_limiter. exception ;
import org. springframework. web. bind. annotation. ExceptionHandler ;
import org. springframework. web. bind. annotation. RestControllerAdvice ;
import java. util. HashMap ;
import java. util. Map ;
/**
 * Converts {@link RateLimitException} into a map returned as the response body.
 */
@RestControllerAdvice
public class GlobalException {

    /**
     * Handles a rate-limit rejection.
     *
     * @param e the rejection raised for the current request
     * @return a map with {@code "status"} = 500 and {@code "message"} = the exception message
     */
    @ExceptionHandler(RateLimitException.class)
    public Map<String, Object> rateLimitException(RateLimitException e) {
        Map<String, Object> map = new HashMap<>();
        map.put("status", 500);
        // Expose only the human-readable message: the stack trace is not a message and
        // would leak internal class/method names to clients.
        map.put("message", e.getMessage());
        return map;
    }
}
2.9、utils包下-IpUtils工具类(获取客户端 IP)
package org.javaboy.rate_limiter.utils;
import javax. servlet. http. HttpServletRequest ;
import java. net. InetAddress ;
import java. net. UnknownHostException ;
/**
 * Helpers for resolving client and host IP addresses.
 */
public class IpUtils {

    /**
     * Headers consulted, in order, for the originating client address when the request
     * passed through a proxy. Servlet header lookup is case-insensitive, so each name is
     * listed only once.
     */
    private static final String[] CLIENT_IP_HEADERS = {
            "x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP", "X-Real-IP"
    };

    /**
     * Returns the client IP address of {@code request}.
     *
     * <p>The first known value among the proxy headers is used, falling back to
     * {@link HttpServletRequest#getRemoteAddr()}. The IPv6 loopback {@code 0:0:0:0:0:0:0:1}
     * is reported as {@code 127.0.0.1}; for a comma-separated forwarding chain the first
     * known entry is returned.
     *
     * <p>NOTE(review): these headers come from the client unless a trusted proxy overwrites
     * them, so a caller can spoof its address (e.g. to evade per-IP rate limits). Only rely
     * on them behind a proxy you control.
     *
     * @param request the current request, may be {@code null}
     * @return the client IP, or {@code "unknown"} when {@code request} is {@code null}
     */
    public static String getIpAddr(HttpServletRequest request) {
        if (request == null) {
            return "unknown";
        }
        String ip = null;
        for (String header : CLIENT_IP_HEADERS) {
            ip = request.getHeader(header);
            if (!isUnknown(ip)) {
                break;
            }
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : getMultistageReverseProxyIp(ip);
    }

    /**
     * Reports whether {@code ip} is a private (RFC 1918) IPv4 address or {@code 127.0.0.1}.
     *
     * <p>Text that is not a parseable IPv4 address (including {@code null}) is treated as
     * internal.
     *
     * @param ip dotted IPv4 text, may be {@code null}
     * @return {@code true} for 10/8, 172.16/12, 192.168/16, 127.0.0.1 or unparseable input
     */
    public static boolean internalIp(String ip) {
        byte[] addr = textToNumericFormatV4(ip);
        return internalIp(addr) || "127.0.0.1".equals(ip);
    }

    /** Classifies numeric IPv4 bytes; {@code null} or too-short input counts as internal. */
    private static boolean internalIp(byte[] addr) {
        if (addr == null || addr.length < 2) {
            return true;
        }
        // Compare octets as unsigned ints to avoid signed-byte surprises.
        int b0 = addr[0] & 0xFF;
        int b1 = addr[1] & 0xFF;
        if (b0 == 10) {
            return true;                        // 10.0.0.0/8
        }
        if (b0 == 172) {
            // 172.16.0.0/12. Must return here: falling into the 192 check would wrongly
            // classify 172.168.x.x as private.
            return b1 >= 16 && b1 <= 31;
        }
        if (b0 == 192) {
            return b1 == 168;                   // 192.168.0.0/16
        }
        return false;
    }

    /**
     * Parses IPv4 text in 1-, 2-, 3- or 4-part dotted form into four bytes.
     *
     * @param text address text, may be {@code null}
     * @return the four address bytes, or {@code null} if {@code text} is null, empty or invalid
     */
    public static byte[] textToNumericFormatV4(String text) {
        if (text == null || text.isEmpty()) {
            return null;
        }
        byte[] bytes = new byte[4];
        String[] elements = text.split("\\.", -1);
        try {
            long l;
            int i;
            switch (elements.length) {
                case 1:
                    l = Long.parseLong(elements[0]);
                    if ((l < 0L) || (l > 4294967295L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l >> 24 & 0xFF);
                    bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 2:
                    l = Integer.parseInt(elements[0]);
                    if ((l < 0L) || (l > 255L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l & 0xFF);
                    l = Integer.parseInt(elements[1]);
                    if ((l < 0L) || (l > 16777215L)) {
                        return null;
                    }
                    bytes[1] = (byte) (int) (l >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 3:
                    for (i = 0; i < 2; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    l = Integer.parseInt(elements[2]);
                    if ((l < 0L) || (l > 65535L)) {
                        return null;
                    }
                    bytes[2] = (byte) (int) (l >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 4:
                    for (i = 0; i < 4; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    break;
                default:
                    return null;
            }
        } catch (NumberFormatException e) {
            return null;
        }
        return bytes;
    }

    /**
     * Returns this host's address, or {@code "127.0.0.1"} if it cannot be resolved.
     */
    public static String getHostIp() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException ignored) {
            // Best effort: fall back to loopback when the local host name cannot be resolved.
        }
        return "127.0.0.1";
    }

    /**
     * Returns this host's name, or {@code "未知"} ("unknown") if it cannot be resolved.
     */
    public static String getHostName() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException ignored) {
            // Best effort: fall back to a placeholder when the local host cannot be resolved.
        }
        return "未知";
    }

    /**
     * Picks the first known address from a comma-separated forwarding chain.
     *
     * @param ip a single address or a {@code "client, proxy1, proxy2"} list, may be {@code null}
     * @return the first entry that is not empty/"unknown" (whitespace-trimmed), or {@code ip}
     *         unchanged if it has no comma or no such entry exists
     */
    public static String getMultistageReverseProxyIp(String ip) {
        if (ip != null && ip.contains(",")) {
            for (String subIp : ip.trim().split(",")) {
                // Entries are usually separated by ", " — trim so callers get a clean address.
                String candidate = subIp.trim();
                if (!isUnknown(candidate)) {
                    return candidate;
                }
            }
        }
        return ip;
    }

    /**
     * Returns {@code true} if {@code checkString} is {@code null}, empty, or
     * {@code "unknown"} (case-insensitive).
     */
    public static boolean isUnknown(String checkString) {
        return checkString == null || checkString.length() == 0 || "unknown".equalsIgnoreCase(checkString);
    }
}
三、测试