/**
 * Aspect that rate-limits controller methods with a Redis counter script.
 *
 * <p>Before every method matched by {@link #rateLimiter()}, a counter keyed by
 * client IP and "declaring class-method name" is incremented through the injected
 * Lua script. When the returned value exceeds the configured count, the hit is
 * written to {@code rate_limiter_log} and a {@code ServiceException} is thrown.
 *
 * <p>Limits come from {@code rate_limiter_config} (cached in Redis) when a row
 * exists for the endpoint; otherwise built-in defaults are chosen by HTTP method.
 */
@Aspect
@Component
@Order(1)
public class RateLimiterAspect {
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    private RedisTemplate<Object, Object> redisTemplate;
    private RedisScript<Long> limitScript;

    @Autowired
    public void setRedisTemplate1(RedisTemplate<Object, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Autowired
    public void setLimitScript(RedisScript<Long> limitScript) {
        this.limitScript = limitScript;
    }

    @Autowired
    private RedisCache redisCache;

    /**
     * Matches every public method in a {@code controller} package (or sub-package)
     * located one level below {@code com.itcast}.
     */
    @Pointcut(value = "execution(public * com.itcast.*.controller..*(..))")
    public void rateLimiter() {
    }

    /**
     * Runs the limit script for the current request and rejects it when over the limit.
     *
     * @param point the intercepted join point
     * @throws ServiceException when the script returns no value or a value above the allowed count
     * @throws RuntimeException when anything else fails while evaluating the limit;
     *                          the original failure is attached as the cause
     */
    @Before("rateLimiter()")
    public void doBefore(JoinPoint point) {
        String requestKey = getRequestKey(point);
        String combineKey = getCombineKey(requestKey);
        HashMap<String, Integer> limiterMap = getLimiterInfo(requestKey);
        int time = limiterMap.get("time");
        int count = limiterMap.get("count");
        List<Object> keys = Collections.singletonList(combineKey);
        try {
            // Script arguments: ARGV[1] = allowed count, ARGV[2] = window length.
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (StringUtils.isNull(number) || number.intValue() > count) {
                insertRateLimiterLog(combineKey, time, count);
                throw new ServiceException("访问过于频繁,请稍候再试");
            }
            log.info("限制请求'{}次',当前请求'{}次',缓存key'{}'", count, number.intValue(), combineKey);
        } catch (ServiceException e) {
            // The rate-limit rejection itself must reach the caller unchanged.
            throw e;
        } catch (Exception e) {
            // Keep the cause so the underlying Redis/DB failure stays diagnosable.
            throw new RuntimeException("服务器限流异常,请稍候再试", e);
        }
    }

    /**
     * Builds the endpoint identifier "declaring class name-method name".
     *
     * @param point the intercepted join point
     * @return fully qualified declaring class name, a dash, and the method name
     */
    public String getRequestKey(JoinPoint point) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> decClass = method.getDeclaringClass();
        return decClass.getName() + "-" + method.getName();
    }

    /**
     * Builds the Redis counter key for the current client and endpoint.
     *
     * @param requestKey endpoint identifier as produced by {@link #getRequestKey(JoinPoint)}
     * @return {@code RATE_LIMIT_KEY} + client IP + "-" + {@code requestKey}
     */
    public String getCombineKey(String requestKey) {
        return RATE_LIMIT_KEY + IpUtils.getIpAddr(ServletUtils.getRequest()) + "-" + requestKey;
    }

    /**
     * Resolves the limit settings for an endpoint.
     *
     * <p>A non-empty result from {@link #getConfigFromDB(String)} wins. Otherwise the
     * loose defaults are used for GET requests and the strict defaults for every
     * other HTTP method.
     *
     * @param requestKey endpoint identifier ("class name-method name")
     * @return map with the keys {@code "time"} and {@code "count"}
     * @throws IllegalStateException when defaults are needed but no HTTP request is bound to the thread
     */
    public HashMap<String, Integer> getLimiterInfo(String requestKey) {
        HashMap<String, Integer> configFromDB = getConfigFromDB(requestKey);
        if (!CollectionUtils.isEmpty(configFromDB)) {
            return configFromDB;
        }
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        // An explicit check instead of `assert`: assertions are disabled by default at
        // runtime, which would turn a missing request context into a bare NPE.
        if (attributes == null) {
            throw new IllegalStateException("No HTTP request is bound to the current thread; cannot choose default rate limits");
        }
        HashMap<String, Integer> map = new HashMap<>();
        String method = attributes.getRequest().getMethod();
        if ("GET".equals(method)) {
            map.put("time", LOOSE_RATE_LIMITER_TIME);
            map.put("count", LOOSE_RATE_LIMITER_COUNT);
        } else {
            map.put("time", STRICT_RATE_LIMITER_TIME);
            map.put("count", STRICT_RATE_LIMITER_COUNT);
        }
        return map;
    }

    /**
     * Records a rejected request in the {@code rate_limiter_log} table.
     *
     * @param combineKey Redis counter key of the rejected request (contains the client IP)
     * @param time       window length that was in effect
     * @param count      allowed request count that was in effect
     */
    public void insertRateLimiterLog(String combineKey, int time, int count) {
        Date nowDate = DateUtils.getNowDate();
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        String limiterTime = sdf.format(nowDate);
        // Values are handed over as SqlRunner positional arguments instead of being
        // concatenated: combineKey embeds the client IP, which must not be able to
        // alter the statement.
        new SqlRunner().insert(
                "insert into rate_limiter_log (combine_key,time,count,limiter_time) values ({0},{1},{2},{3})",
                combineKey, time, count, limiterTime);
    }

    /**
     * Loads the configured limits for an endpoint, preferring the Redis cache.
     *
     * <p>The cache entry is a hash stored under {@code RATE_LIMITER_KEY + requestKey}
     * with the fields {@code "time"} and {@code "count"}. On a cache miss the
     * {@code rate_limiter_config} table is queried and, when a row is found, the
     * values are cached with an expiry of {@code 5 * 60} (five minutes, assuming
     * {@code RedisCache.expire} takes seconds).
     *
     * @param requestKey endpoint identifier ("class name-method name")
     * @return map with {@code "time"} and {@code "count"}, or an empty map when the
     *         endpoint has no configuration
     */
    public HashMap<String, Integer> getConfigFromDB(String requestKey) {
        HashMap<String, Integer> map = new HashMap<>();
        String redisKey = RATE_LIMITER_KEY + requestKey;
        Map<String, Object> cacheMap = redisCache.getCacheMap(redisKey);
        // Look up the same fields that are written below; the cache can only hit
        // when both are present.
        if (cacheMap != null && cacheMap.get("time") != null && cacheMap.get("count") != null) {
            map.put("time", Integer.parseInt(cacheMap.get("time").toString()));
            map.put("count", Integer.parseInt(cacheMap.get("count").toString()));
            return map;
        }
        Map<String, Object> resultMap = new SqlRunner().selectOne(
                "select time,count from rate_limiter_config where del_flag =0 and requestKey ={0}",
                requestKey);
        if (ObjectUtils.isNotEmpty(resultMap)) {
            map.put("time", Integer.parseInt(resultMap.get("time").toString()));
            map.put("count", Integer.parseInt(resultMap.get("count").toString()));
            redisCache.setCacheMap(redisKey, map);
            redisCache.expire(redisKey, 5 * 60);
        }
        return map;
    }
}
/**
 * Redis configuration.
 */
@Configuration
@EnableCaching
public class RedisConfig extends CachingConfigurerSupport {

    /**
     * Creates the application's {@code RedisTemplate}.
     *
     * <p>Keys and hash keys are serialized with {@code StringRedisSerializer};
     * values and hash values with {@code FastJson2JsonRedisSerializer}.
     *
     * @param connectionFactory factory the template obtains its connections from
     * @return the initialized template
     */
    @Bean
    @SuppressWarnings(value = { "unchecked", "rawtypes" })
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        FastJson2JsonRedisSerializer jsonSerializer = new FastJson2JsonRedisSerializer(Object.class);

        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(connectionFactory);

        // Plain keys and their values.
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        redisTemplate.setValueSerializer(jsonSerializer);

        // Hash fields follow the same scheme as plain keys and values.
        redisTemplate.setHashKeySerializer(new StringRedisSerializer());
        redisTemplate.setHashValueSerializer(jsonSerializer);

        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * Exposes the rate-limit Lua script as a bean whose result is read as a {@code Long}.
     *
     * @return the script wrapper around {@link #limitScriptText()}
     */
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptText(limitScriptText());
        return script;
    }

    /**
     * Source of the rate-limit Lua script.
     *
     * <p>{@code KEYS[1]} is the counter key, {@code ARGV[1]} the allowed count and
     * {@code ARGV[2]} the expiry applied when the counter is first created. If the
     * stored counter is already greater than the allowed count it is returned
     * unchanged; otherwise the counter is incremented and the new value returned.
     *
     * @return the Lua script text
     */
    private String limitScriptText() {
        return String.join("\n",
                "local key = KEYS[1]",
                "local count = tonumber(ARGV[1])",
                "local time = tonumber(ARGV[2])",
                "local current = redis.call('get', key);",
                "if current and tonumber(current) > count then",
                " return tonumber(current);",
                "end",
                "current = redis.call('incr', key)",
                "if tonumber(current) == 1 then",
                " redis.call('expire', key, time)",
                "end",
                "return tonumber(current);");
    }
}