引入依赖
<dependencies>
    <!-- Versions are omitted; presumably supplied by spring-boot-starter-parent or the Spring Boot BOM (not shown here) - confirm. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>
    <!-- Required for the @Aspect based rate-limit interceptor (RateAspect). -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-aop</artifactId>
    </dependency>
    <!-- Lettuce client behind LettuceConnectionFactory.
         NOTE(review): spring-boot-starter-data-redis normally brings this in already, so this entry is likely redundant. -->
    <dependency>
        <groupId>io.lettuce</groupId>
        <artifactId>lettuce-core</artifactId>
    </dependency>
    <!-- Test-only libraries: the test scope keeps them off the runtime / packaged classpath. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
    <!-- Provides DateFormatUtils used by LuaController. -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
    </dependency>
</dependencies>
application.properties 增加redis配置
server.port=8081
# Redis server location (hard-coded LAN address; adjust per environment).
spring.redis.host=192.168.0.103
spring.redis.port=6379
# NOTE(review): Commons wires a LettuceConnectionFactory, so these jedis.* pool settings are
# most likely ignored. The Lettuce equivalents are spring.redis.lettuce.pool.* (which also need
# commons-pool2 on the classpath) - confirm before relying on connection pooling.
spring.redis.jedis.pool.max-active=8
# NOTE(review): no unit given; Spring Boot reads unit-less durations as milliseconds,
# so this is a 1 ms wait - probably not intended, confirm.
spring.redis.jedis.pool.max-wait=1
# Command timeout; unit-less, so read as milliseconds (10 s).
spring.redis.timeout=10000
编写Lua脚本
vim rateLimit.lua
:wq
将 rateLimit.lua 放置在 src/main/resources 目录下(即类路径根目录,Commons 中通过 new ClassPathResource("rateLimit.lua") 加载),脚本内容如下:
-- Fixed-window rate limiter.
--   KEYS[1]  counter key for one caller/endpoint
--   ARGV[1]  TTL in seconds applied to the counter when it is first created (the window length)
--   ARGV[2]  maximum number of hits allowed while the counter is alive
-- Returns 1 when the hit is allowed, 0 when it exceeds the limit.
-- Redis runs the whole script atomically, so INCR and EXPIRE cannot interleave with other clients.
local hits = tonumber(redis.call('incr', KEYS[1]))

if hits == 1 then
    -- First hit of a new window: start the window timer.
    redis.call('expire', KEYS[1], ARGV[1])
    return 1
end

if hits > tonumber(ARGV[2]) then
    return 0
end

return 1
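注:1、通过 KEYS[1] 获取传入的 key 参数;
2、通过 ARGV[1] 获取限流时间窗口(秒),用作该 key 的过期时间;通过 ARGV[2] 获取时间窗口内允许的最大访问次数(limit);
3、调用 redis.call('incr', KEYS[1]) 将该 key 的计数加 1(key 不存在时视为 0,结果为 1);4、如果计数为 1,说明是当前窗口内的第一次访问,为该 key 设置 ARGV[1] 秒的过期时间,并返回 1;5、如果计数大于 ARGV[2],说明已超过限制,返回 0,表示该请求应被限流;6、否则返回 1,表示允许访问。整个脚本在 Redis 中原子执行,incr 与 expire 之间不会插入其他客户端的命令。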
限流注解
package com.annotation;
import java.lang.annotation.*;
/**
 * Marks a controller method as rate limited: at most {@link #count()} calls per window of
 * {@link #time()} seconds.
 */
@Target(ElementType.METHOD)
@Documented
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {

    /**
     * Identifier of this limit; it is appended to the per-caller Redis key.
     *
     * @return the key suffix, empty by default
     */
    String key() default "";

    /**
     * Length of the rate-limit window, in seconds.
     *
     * @return window length in seconds
     */
    int time();

    /**
     * Maximum number of calls allowed within one window.
     *
     * @return maximum calls per window
     */
    int count();
}
公共配置 RedisTemplate配置
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
import java.io.Serializable;
/**
 * Shared Redis beans for the rate limiter: the Lua script and the RedisTemplate that executes it.
 */
@Component
public class Commons {

    /**
     * Loads {@code rateLimit.lua} from the classpath root as a Redis script whose reply is a number
     * (the script returns 1 to allow a call and 0 to reject it).
     *
     * @return the rate-limit script bean
     */
    @Bean
    public DefaultRedisScript<Number> defaultRedisScript() {
        // Diamond instead of the raw `new DefaultRedisScript()`, which compiled with an unchecked warning.
        DefaultRedisScript<Number> defaultRedisScript = new DefaultRedisScript<>();
        defaultRedisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("rateLimit.lua")));
        defaultRedisScript.setResultType(Number.class);
        return defaultRedisScript;
    }

    /**
     * RedisTemplate with String keys and Jackson-JSON values.
     *
     * <p>NOTE(review): script arguments are encoded with the value serializer, so callers should pass
     * numeric arguments (not Strings) to keep them parseable by Lua's {@code tonumber} - confirm if the
     * serializer is changed.
     *
     * @param connectionFactory the Lettuce connection factory to bind the template to
     * @return the configured template
     */
    @Bean
    public RedisTemplate<String, Serializable> redisTemplate(LettuceConnectionFactory connectionFactory) {
        RedisTemplate<String, Serializable> redisTemplate = new RedisTemplate<>();
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        redisTemplate.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        redisTemplate.setConnectionFactory(connectionFactory);
        return redisTemplate;
    }
}
利用AOP编写拦截器
import com.annotation.RateLimit;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
/**
 * Enforces {@link RateLimit} on controller methods under {@code com.controller}.
 *
 * <p>For each annotated call it runs {@code rateLimit.lua} against a Redis counter keyed by client IP,
 * declaring class, method name and {@link RateLimit#key()}, and rejects the call with a
 * {@link RuntimeException} once the script reports that the limit is exceeded.
 */
@Aspect
@Configuration
public class RateAspect {

    // Named after this aspect; the original used RateLimit.class, which mislabelled the log category.
    private static final Logger logger = LoggerFactory.getLogger(RateAspect.class);

    /** Placeholder used in the key when no client address can be determined. */
    private static final String UNKNOWN_IP = "unknown";

    @Autowired
    private RedisTemplate redisTemplate;

    @Autowired
    private DefaultRedisScript defaultRedisScript;

    /**
     * Applies the rate limit to annotated controller methods; unannotated methods pass straight through.
     *
     * @param joinPoint the intercepted controller invocation
     * @return whatever the controller method returns
     * @throws RuntimeException when the caller has exhausted {@link RateLimit#count()} within the window
     * @throws Throwable anything thrown by the controller method itself
     */
    @Around("execution(* com.controller..*(..))")
    public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
        Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        RateLimit rateLimit = method.getAnnotation(RateLimit.class);
        if (rateLimit == null) {
            return joinPoint.proceed();
        }

        String limitKey = buildLimitKey(getIpAddress(currentRequest()), method, rateLimit);
        logger.debug("Rate limit key: {}", limitKey);

        List<String> keys = Collections.singletonList(limitKey);
        // rateLimit.lua uses ARGV[1] as the EXPIRE (window length in seconds) and ARGV[2] as the
        // maximum count. The original passed count first, which swapped the two whenever time != count.
        Number allowed = (Number) redisTemplate.execute(
                defaultRedisScript, keys, rateLimit.time(), rateLimit.count());
        // The script answers 1 when the call is allowed and 0 when it is over the limit.
        if (allowed != null && allowed.longValue() > 0) {
            logger.info("限制时间段内访问:{}次", allowed);
            return joinPoint.proceed();
        }
        throw new RuntimeException("已经到限制限流次数");
    }

    /** Joins client IP, declaring class, method name and the annotation key with '-'. */
    private static String buildLimitKey(String ipAddress, Method method, RateLimit rateLimit) {
        return ipAddress
                + "-" + method.getDeclaringClass().getName()
                + "-" + method.getName()
                + "-" + rateLimit.key();
    }

    /** Returns the servlet request bound to the current thread, or null outside a web request. */
    private static HttpServletRequest currentRequest() {
        RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes) {
            return ((ServletRequestAttributes) attributes).getRequest();
        }
        return null;
    }

    /**
     * Resolves the client address used in the limit key.
     *
     * <p>Checks X-Forwarded-For, then Proxy-Client-IP, then falls back to the socket address. Without
     * that fallback every direct (non-proxied) client shared a single bucket.
     *
     * <p>NOTE(review): both headers are client-controlled unless a trusted reverse proxy overwrites
     * them; a client can rotate them to obtain a fresh bucket. Only trust them behind a proxy you control.
     *
     * @param request the current request, may be null
     * @return the client address, or {@code "unknown"} when none can be determined
     */
    private String getIpAddress(HttpServletRequest request) {
        if (request == null) {
            return UNKNOWN_IP;
        }
        String ipAddr = firstHeaderValue(request, "X-FORWARDED-FOR");
        if (ipAddr == null) {
            ipAddr = firstHeaderValue(request, "Proxy-Client-IP");
        }
        if (ipAddr == null) {
            ipAddr = request.getRemoteAddr();
        }
        return (ipAddr == null || ipAddr.isEmpty()) ? UNKNOWN_IP : ipAddr;
    }

    /**
     * Returns the first comma-separated entry of a header, or null when it is absent, blank or
     * "unknown". Splitting on the comma avoids the original substring(0, indexOf(",")), which threw
     * (and silently yielded null) for long values without a comma, e.g. IPv6 addresses.
     */
    private static String firstHeaderValue(HttpServletRequest request, String header) {
        String value = request.getHeader(header);
        if (value == null) {
            return null;
        }
        int comma = value.indexOf(',');
        String first = (comma >= 0 ? value.substring(0, comma) : value).trim();
        if (first.isEmpty() || UNKNOWN_IP.equalsIgnoreCase(first)) {
            return null;
        }
        return first;
    }
}
编写用户请求
注:添加 @RateLimit(key = "test", time = 10, count = 10) 注解后,会在 Redis 中生成一个"10 秒内最多可以访问 10 次"的限流 key(key 由客户端 IP、类名、方法名和注解中的 key 拼接而成)。
/**
 * Demo endpoint guarded by {@link RateLimit}: 10 calls per 10-second window.
 */
@Controller
public class LuaController {

    // Only its connection factory is used, to back the Redis counter below.
    @Autowired
    private RedisTemplate redisTemplate;

    /**
     * Returns the current time followed by the value of the Redis counter {@code entryIdCounter}
     * as read before this call increments it.
     *
     * @return timestamp + "累计访问次数:" + previous counter value
     */
    @RateLimit(key = "test", time = 10, count = 10)
    @GetMapping("/lua")
    @ResponseBody
    public String luaTimer() {
        RedisAtomicInteger visitCounter =
                new RedisAtomicInteger("entryIdCounter", redisTemplate.getConnectionFactory());
        String timestamp = DateFormatUtils.format(new Date(), "yyyy-MM-dd HH:mm:ss.SS");
        int previousVisits = visitCounter.getAndIncrement();
        return new StringBuilder(timestamp)
                .append("累计访问次数:")
                .append(previousVisits)
                .toString();
    }
}
启动服务
启动服务后访问 /lua:10 秒内最多可以访问 10 次,超过 10 次页面就会报错;等 10 秒的时间窗口过期后,重新开始计数。