一、目标
通过注解的方式，限制单个用户对某个 API 的访问频率。
二、案例代码
1)、定义注解
定义注解：用于限制单个用户在自定义时间窗口（expireTime，单位：秒）内访问某个接口的次数
/**
 * Caps how many times a single caller may hit the annotated endpoint within a time window.
 *
 * <p>The aspect builds its Redis counter key from {@link #requestUrl()} plus the caller's user id.
 * NOTE(review): {@link ElementType#TYPE} is a declared target, but the aspect's pointcut is
 * {@code @annotation(...)}, which only matches method-level annotations — confirm whether
 * class-level use is meant to work.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface RequestLimit {

    /**
     * Identifier embedded in the counter key. Endpoints that leave this empty share a single
     * counter per user.
     */
    String requestUrl() default "";

    /** Call budget for one window. */
    int limitCount() default 3;

    /** Window length in seconds; used as the TTL of the counter key. */
    long expireTime() default 60L;
}
定义注解:用于限制某个接口当天访问频次
/**
 * Caps how many times a single caller may hit the annotated endpoint per calendar day.
 *
 * <p>The counter key's TTL is set to the seconds remaining until the end of the current day.
 * NOTE(review): the day boundary comes from Hutool's {@code DateUtil.endOfDay}, presumably in the
 * JVM default time zone — confirm. {@link ElementType#TYPE} is declared, but the aspect's
 * {@code @annotation(...)} pointcut only matches method-level annotations.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface RequestDayLimit {

    /**
     * Identifier embedded in the counter key. Endpoints that leave this empty share a single
     * daily counter per user.
     */
    String requestUrl() default "";

    /** Call budget for one day. */
    int limitCount() default 3;
}
2)、编写AOP
/**
 * AOP aspect that enforces {@link RequestLimit} and {@link RequestDayLimit} with per-caller
 * counters stored in Redis (INCR + EXPIRE).
 *
 * <p>NOTE: INCR and EXPIRE are issued as two separate commands here. If EXPIRE fails after INCR
 * succeeded, the counter never expires. Running both steps in one Lua script makes them atomic.
 */
@Aspect
@Component
public class RequestLimitAspect {
    @Autowired
    private RedisTemplate redisTemplate;
    @Autowired
    private OpenJwtUtils openJwtUtils;

    /** Matches methods annotated with {@link RequestLimit} and binds the annotation instance. */
    @Pointcut("@annotation(requestLimit)")
    public void limit(RequestLimit requestLimit) {}

    /** Matches methods annotated with {@link RequestDayLimit} and binds the annotation instance. */
    @Pointcut("@annotation(requestDayLimit)")
    public void dayLimit(RequestDayLimit requestDayLimit) {}

    /**
     * Lets the call through until the caller has made more than {@code limitCount} calls within
     * {@code expireTime} seconds of the first call, then throws {@code TOO_MANY_REQUESTS}.
     * Calls made outside a servlet request are not limited.
     */
    @Around("limit(requestLimit)")
    public Object requestLimitLog(ProceedingJoinPoint joinPoint, RequestLimit requestLimit) throws Throwable {
        HttpServletRequest request = currentRequest();
        if (request != null) {
            String key = "requestLimit-" + requestLimit.requestUrl() + "-" + resolveCaller(request);
            long count = incrementCounter(key);
            if (count == 1) {
                // First hit opens the window.
                redisTemplate.expire(key, requestLimit.expireTime(), TimeUnit.SECONDS);
            }
            // '>' rather than '>=': the limitCount-th call itself must still be allowed.
            if (count > requestLimit.limitCount()) {
                // Over the limit: report the corresponding error-code enum value.
                throw new CommonException(CommonErrorCode.TOO_MANY_REQUESTS);
            }
        }
        return joinPoint.proceed();
    }

    /**
     * Lets the call through until the caller has made more than {@code limitCount} calls today,
     * then throws {@code DAILY_TOO_MANY_REQUESTS}. The counter expires at the end of the current day.
     */
    @Around("dayLimit(requestDayLimit)")
    public Object requestDayLimitLog(ProceedingJoinPoint joinPoint, RequestDayLimit requestDayLimit) throws Throwable {
        HttpServletRequest request = currentRequest();
        if (request != null) {
            String key = "requestDayLimit-" + requestDayLimit.requestUrl() + "-" + resolveCaller(request);
            long count = incrementCounter(key);
            if (count == 1) {
                redisTemplate.expire(key, secondsUntilEndOfDay(), TimeUnit.SECONDS);
            }
            if (count > requestDayLimit.limitCount()) {
                throw new CommonException(CommonErrorCode.DAILY_TOO_MANY_REQUESTS);
            }
        }
        return joinPoint.proceed();
    }

    /** Returns the current servlet request, or null when not running inside one. */
    private static HttpServletRequest currentRequest() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        return attributes == null ? null : attributes.getRequest();
    }

    /**
     * Identifies the caller for the counter key: the JWT user id when present, otherwise the
     * client address (prefixed with "ip:") so anonymous callers are limited too instead of
     * failing with an NPE.
     * NOTE(review): getTmsUserInfo is the JWT helper shown in the reworked aspect; it returns null
     * when no token is present. Behind a proxy, getRemoteAddr() is the proxy's address — confirm.
     */
    private String resolveCaller(HttpServletRequest request) {
        TmsUserInfo tmsUserInfo = getTmsUserInfo(request);
        if (tmsUserInfo == null) {
            return "ip:" + request.getRemoteAddr();
        }
        return String.valueOf(tmsUserInfo.getId());
    }

    /** Increments the counter at {@code key} and returns the new value, failing clearly on a null reply. */
    private long incrementCounter(String key) {
        Long count = redisTemplate.opsForValue().increment(key, 1L);
        if (count == null) {
            // Avoid an opaque NPE from unboxing; increment() is declared as possibly returning null.
            throw new IllegalStateException("Redis INCR returned no value for key " + key);
        }
        return count;
    }

    /**
     * Seconds left until the end of today, never less than 1. A TTL of 0 would make EXPIRE delete
     * the key immediately, losing the count.
     */
    private static long secondsUntilEndOfDay() {
        Date now = new Date();
        DateTime endOfDay = DateUtil.endOfDay(now);
        return Math.max(1L, DateUtil.between(now, endOfDay, DateUnit.SECOND));
    }
}
3)、测试案例
示例：限制每个用户每天最多访问该接口 5 次
/**
 * Demo endpoint guarded by {@link RequestDayLimit}, configured with a daily limit of 5 calls.
 * The aspect keys the counter by requestUrl ("/dynamicDb") plus the caller's user id.
 */
@GetMapping("/dynamicDb")
@RequestDayLimit(requestUrl = "/dynamicDb", limitCount = 5)
public R sayHi(){
// Business logic elided in the article.
...
return R.ok();
}
三、案例改造
上面的实现中，INCR（计数加 1）和 EXPIRE（设置过期时间）是两次独立的 Redis 调用。在高并发或网络抖动的情况下，假设第一次请求时 INCR 成功创建了 key，但随后 EXPIRE 因超时等原因失败，这个计数 key 就永远不会过期，该用户会被永久限流。要解决这个问题，需要把“计数加 1”和“设置过期时间”合并为一个原子操作，可以通过 Lua 脚本来实现。
改造以后的案例:
/**
 * AOP aspect that enforces {@link RequestLimit} and {@link RequestDayLimit} with per-caller
 * counters in Redis. The check, the increment and the first-hit EXPIRE run inside a single Lua
 * script, so a counter can no longer be left without a TTL.
 */
@Aspect
@Component
public class RequestLimitAspect {

    /**
     * Limiter script, built once instead of on every request.
     * KEYS[1] = counter key, ARGV[1] = limit, ARGV[2] = TTL in seconds; returns the current count.
     */
    private static final RedisScript<Long> LIMIT_SCRIPT = new DefaultRedisScript<>(buildLuaScript(), Long.class);

    @Autowired
    private StringRedisTemplate redisTemplate;
    @Autowired
    private OpenJwtUtils openJwtUtils;

    /** Matches methods annotated with {@link RequestLimit} and binds the annotation instance. */
    @Pointcut("@annotation(requestLimit)")
    public void limit(RequestLimit requestLimit) {}

    /** Matches methods annotated with {@link RequestDayLimit} and binds the annotation instance. */
    @Pointcut("@annotation(requestDayLimit)")
    public void dayLimit(RequestDayLimit requestDayLimit) {}

    /**
     * Allows at most {@code limitCount} calls per caller within {@code expireTime} seconds of the
     * first call; otherwise throws {@code TOO_MANY_REQUESTS}. Calls outside a servlet request are
     * not limited.
     */
    @Around("limit(requestLimit)")
    public Object requestLimitLog(ProceedingJoinPoint joinPoint, RequestLimit requestLimit) throws Throwable {
        HttpServletRequest request = currentRequest();
        if (request != null) {
            String key = "requestLimit-" + requestLimit.requestUrl() + "-" + resolveCaller(request);
            // The fixed-window limit reports TOO_MANY_REQUESTS; DAILY_* is reserved for the daily limit.
            checkLimit(key, requestLimit.limitCount(), requestLimit.expireTime(),
                    CommonErrorCode.TOO_MANY_REQUESTS);
        }
        return joinPoint.proceed();
    }

    /**
     * Allows at most {@code limitCount} calls per caller per day; otherwise throws
     * {@code DAILY_TOO_MANY_REQUESTS}. The counter's TTL runs until the end of the current day.
     */
    @Around("dayLimit(requestDayLimit)")
    public Object requestDayLimitLog(ProceedingJoinPoint joinPoint, RequestDayLimit requestDayLimit) throws Throwable {
        HttpServletRequest request = currentRequest();
        if (request != null) {
            String key = "requestDayLimit-" + requestDayLimit.requestUrl() + "-" + resolveCaller(request);
            checkLimit(key, requestDayLimit.limitCount(), secondsUntilEndOfDay(),
                    CommonErrorCode.DAILY_TOO_MANY_REQUESTS);
        }
        return joinPoint.proceed();
    }

    /**
     * Runs the limiter script for {@code key} and throws {@code errorCode} when the count exceeds
     * {@code limitCount} or Redis returns no value.
     */
    private void checkLimit(String key, int limitCount, long ttlSeconds, CommonErrorCode errorCode) {
        Long count = redisTemplate.execute(LIMIT_SCRIPT, ImmutableList.of(key),
                String.valueOf(limitCount), String.valueOf(ttlSeconds));
        if (count == null || count > limitCount) {
            throw new CommonException(errorCode);
        }
    }

    /** Returns the current servlet request, or null when not running inside one. */
    private static HttpServletRequest currentRequest() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        return attributes == null ? null : attributes.getRequest();
    }

    /**
     * Identifies the caller for the counter key: the JWT user id when a token is present, otherwise
     * the client address (prefixed with "ip:") so anonymous callers are limited too instead of
     * failing with an NPE.
     * NOTE(review): behind a reverse proxy getRemoteAddr() is the proxy's address — confirm.
     */
    private String resolveCaller(HttpServletRequest request) {
        TmsUserInfo tmsUserInfo = getTmsUserInfo(request);
        if (tmsUserInfo == null) {
            return "ip:" + request.getRemoteAddr();
        }
        return String.valueOf(tmsUserInfo.getId());
    }

    /**
     * Seconds left until the end of today, never less than 1. A TTL of 0 would make EXPIRE delete
     * the key immediately, losing the count.
     */
    private static long secondsUntilEndOfDay() {
        Date now = new Date();
        DateTime endOfDay = DateUtil.endOfDay(now);
        return Math.max(1L, DateUtil.between(now, endOfDay, DateUnit.SECOND));
    }

    /**
     * Extracts the current user from the JWT in the TMS social-token cookie, or from the
     * Authorization header with its "Bearer " prefix removed. The cookie wins when both are present.
     *
     * @param request the current request
     * @return the user built from the JWT claims, or null when neither source carries a token
     */
    private TmsUserInfo getTmsUserInfo(HttpServletRequest request) {
        TmsUserInfo tmsUserInfo = null;
        Cookie cookie = ServletUtil.getCookie(request, CommonConstants.TMS_SOCIAL_TOKEN);
        String jwtHeader = ServletUtil.getHeaderIgnoreCase(request, CommonConstants.TMS_AUTHORIZATION);
        String bearer = StringUtils.remove(jwtHeader, "Bearer ");
        if (cookie != null || StringUtils.isNotEmpty(bearer)) {
            String jwt = cookie == null ? bearer : cookie.getValue();
            Claims claims = openJwtUtils.parseJwt(jwt);
            tmsUserInfo = BeanUtil.fillBeanWithMapIgnoreCase(claims, new TmsUserInfo(), false);
        }
        return tmsUserInfo;
    }

    /**
     * Builds the Lua script that counts calls atomically.
     * If the stored count already exceeds ARGV[1], it is returned without incrementing.
     * Otherwise the key is INCRed, and on the first increment (result 1) EXPIRE sets a TTL of
     * ARGV[2] seconds. The script returns the resulting count.
     *
     * @return the Lua source
     */
    private static String buildLuaScript() {
        StringBuilder lua = new StringBuilder();
        lua.append("local c");
        lua.append("\nc = redis.call('GET',KEYS[1])");
        // Already over the limit: return the count without incrementing further.
        lua.append("\nif c and tonumber(c) > tonumber(ARGV[1]) then");
        lua.append("\nreturn tonumber(c);");
        lua.append("\nend");
        // Increment the counter.
        lua.append("\nc = redis.call('incr',KEYS[1])");
        lua.append("\nif tonumber(c) == 1 then");
        // The window starts at the first call: set the key's expiry in the same atomic script.
        lua.append("\nredis.call('EXPIRE',KEYS[1],tonumber(ARGV[2]))");
        lua.append("\nend");
        lua.append("\nreturn c;");
        return lua.toString();
    }
}