Java接口访问次数限制(以请求URI + IP + 用户ID作为唯一标识)
1、获取用户IP工具类
/**
 * Returns the best-effort IP address of the client that sent {@code request}.
 *
 * <p>The {@code x-forwarded-for}, {@code Proxy-Client-IP} and {@code WL-Proxy-Client-IP}
 * headers are consulted in that order, falling back to {@link HttpServletRequest#getRemoteAddr()}.
 * A header is skipped when it is missing, empty, or equal (ignoring case) to {@code UNKNOWN}.
 * If the chosen value is a comma-separated proxy chain, only its first (originating) entry is
 * returned, trimmed. A loopback address (IPv4 or IPv6) is replaced by this host's own address.
 *
 * <p>NOTE(review): the proxy headers are client-controlled unless a trusted reverse proxy
 * overwrites them; do not rely on the result for security decisions without such a proxy.
 *
 * @param request the current HTTP request
 * @return the client IP, or {@code null} if no address could be determined
 */
public static String getIp(HttpServletRequest request) {
    String ip = request.getHeader("x-forwarded-for");
    if (isMissingIp(ip)) {
        ip = request.getHeader("Proxy-Client-IP");
    }
    if (isMissingIp(ip)) {
        ip = request.getHeader("WL-Proxy-Client-IP");
    }
    if (isMissingIp(ip)) {
        ip = request.getRemoteAddr();
    }
    if (ip == null) {
        // getRemoteAddr() may return null (e.g. in mocked or non-socket requests).
        return null;
    }
    int comma = ip.indexOf(',');
    if (comma >= 0) {
        // "client, proxy1, proxy2": the first entry is the originating client.
        ip = ip.substring(0, comma).trim();
    }
    if ("127.0.0.1".equals(ip) || "0:0:0:0:0:0:0:1".equals(ip) || "::1".equals(ip)) {
        try {
            ip = InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException ignored) {
            // Best effort: if the local host name cannot be resolved, the loopback
            // address is still a valid answer, so it is returned unchanged.
        }
    }
    return ip;
}

/** Returns {@code true} when a header value carries no usable IP address. */
private static boolean isMissingIp(String ip) {
    return ip == null || ip.isEmpty() || UNKNOWN.equalsIgnoreCase(ip);
}
2、接口限制注解(切面)
(1)Controller层使用示例(标注在需要限流的接口方法上)
// Example: place on a controller method. With isFlag = true the aspect resolves the
// ${...} placeholders from configuration; "time" is the counting window in minutes and
// "count" the maximum number of calls allowed within one window.
@LimitRequest(time = "${limit.request.time}", count = "${limit.request.count}", isFlag = true)
(2)注解类
import java.lang.annotation.*;
/**
 * Marks a method whose invocations are rate-limited by {@code LimitRequestAspect}.
 *
 * <p>The attributes are strings so that they can hold {@code ${...}} placeholders. The
 * placeholders are resolved only when {@link #isFlag()} is {@code true}. Otherwise the values
 * must be plain numbers.
 */
@Documented
@Retention(value = RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface LimitRequest {

    /** Length of the counting window, in minutes (number or placeholder). */
    String time() default "1";

    /** Maximum number of calls permitted within one window (number or placeholder). */
    String count() default "100";

    /** When {@code true}, {@link #time()} and {@link #count()} are resolved as placeholders. */
    boolean isFlag() default false;
}
(3)切面类(限流逻辑实现)
import com.alibaba.fastjson.JSON;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanExpressionContext;
import org.springframework.beans.factory.config.BeanExpressionResolver;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.config.Scope;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.util.StringValueResolver;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.util.concurrent.TimeUnit;
/**
 * Aspect that limits how often a client may invoke a method annotated with {@link LimitRequest}.
 *
 * <p>Calls are counted in Redis under a key built from the request URI, the client IP and the
 * user id. The counter expires {@code time} minutes after the first call of a window. While it
 * exceeds {@code count}, calls are rejected with a failure result instead of being executed.
 */
@Aspect
@Component
public class LimitRequestAspect {

    private static final Logger logger = LoggerFactory.getLogger(LimitRequestAspect.class);

    /** Placeholder value proxies put in the IP headers when the client address is unknown. */
    private static final String UNKNOWN_IP = "unknown";

    @Autowired
    private StringRedisTemplate redisTemplate;

    private final BeanExpressionContext exprContext;
    private final BeanExpressionResolver exprResolver;

    /**
     * Creates the aspect.
     *
     * @param beanFactory factory used to resolve {@code ${...}} placeholders and expressions in
     *     annotation attributes
     */
    public LimitRequestAspect(ConfigurableBeanFactory beanFactory) {
        this.exprContext = new BeanExpressionContext(beanFactory, (Scope) null);
        this.exprResolver = beanFactory.getBeanExpressionResolver();
    }

    /**
     * Resolves embedded property placeholders in {@code strVal}, then evaluates the result with
     * the bean factory's expression resolver when one is configured.
     *
     * @param strVal raw annotation value, e.g. {@code "${limit.request.count}"}
     * @return the resolved value, or {@code null} if resolution produced {@code null}
     */
    public String resolveStringValue(String strVal) {
        String value = this.exprContext.getBeanFactory().resolveEmbeddedValue(strVal);
        if (this.exprResolver != null && value != null) {
            Object evaluated = this.exprResolver.evaluate(value, this.exprContext);
            value = evaluated != null ? evaluated.toString() : null;
        }
        return value;
    }

    /** Matches every method carrying {@link LimitRequest} and binds the annotation instance. */
    @Pointcut("@annotation(limitRequest)")
    public void excudeService(LimitRequest limitRequest) {
    }

    /**
     * Counts the call and either proceeds or rejects it once the limit is exceeded.
     *
     * @param pjp the intercepted invocation
     * @param limitRequest the annotation on the intercepted method
     * @return the method's result, or a failure result when the limit is exceeded
     * @throws Throwable whatever the intercepted method throws
     */
    @Around("excudeService(limitRequest)")
    public Object doAround(ProceedingJoinPoint pjp, LimitRequest limitRequest) throws Throwable {
        RequestAttributes ra = RequestContextHolder.getRequestAttributes();
        if (!(ra instanceof ServletRequestAttributes)) {
            // No servlet request is bound to this thread (e.g. an internal or scheduled call), so
            // there is no client to identify. Run unlimited rather than failing with an NPE.
            logger.warn("No servlet request bound to current thread, skipping rate limit for {}",
                    pjp.getSignature());
            return pjp.proceed();
        }
        HttpServletRequest request = ((ServletRequestAttributes) ra).getRequest();

        String ipStr = resolveClientIp(request);
        String userId = resolveUserId(pjp.getArgs(), request);
        logger.info("===========切面获取的用户id:{}", userId);
        String key = StringUtils.join(request.getRequestURI(), "-", ipStr, "-", userId);
        logger.info("===========切面获取的唯一key:{}", key);

        // Parse the limits before touching Redis, so a misconfiguration never mutates the counter.
        int totalCount;
        long time;
        if (limitRequest.isFlag()) {
            totalCount = Integer.parseInt(this.resolveStringValue(limitRequest.count()));
            time = Long.parseLong(this.resolveStringValue(limitRequest.time()));
        } else {
            totalCount = Integer.parseInt(limitRequest.count());
            time = Long.parseLong(limitRequest.time());
        }

        // INCR is atomic: concurrent requests can no longer read the same old count and all slip
        // past the limit. A missing key starts at 0, so the result is the call number in this window.
        Long hits = redisTemplate.opsForValue().increment(key);
        if (hits == null) {
            // increment() returns null inside a pipeline/transaction; the count is unknown here.
            return pjp.proceed();
        }
        if (hits == 1L) {
            // First call of a new window: start the expiry clock.
            redisTemplate.expire(key, time, TimeUnit.MINUTES);
        }
        if (hits > totalCount) {
            // Repair counters that never received a TTL (e.g. a crash between INCR and EXPIRE, or
            // keys left by the former GET-then-INCR race). Otherwise the client is blocked forever.
            Long ttl = redisTemplate.getExpire(key);
            if (ttl != null && ttl == -1L) {
                redisTemplate.expire(key, time, TimeUnit.MINUTES);
            }
            return ResultDTO.failure("接口访问次数超过限制,请" + time + "分钟后重试");
        }
        return pjp.proceed();
    }

    /**
     * Returns the originating client IP, consulting proxy headers before the socket address.
     *
     * <p>Only the first entry of a comma-separated chain is used, which keeps the rate-limit key
     * stable regardless of the proxies the request passed through.
     *
     * <p>NOTE(review): these headers are client-controlled unless a trusted proxy overwrites them.
     * A client could rotate forged values to evade the limit. Confirm the deployment sanitises them.
     */
    private static String resolveClientIp(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (isUnknownIp(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknownIp(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknownIp(ip)) {
            ip = request.getRemoteAddr();
        }
        if (ip != null) {
            int comma = ip.indexOf(',');
            if (comma >= 0) {
                ip = ip.substring(0, comma).trim();
            }
        }
        return ip;
    }

    /** Returns {@code true} when a header value carries no usable IP address. */
    private static boolean isUnknownIp(String ip) {
        return StringUtils.isBlank(ip) || UNKNOWN_IP.equalsIgnoreCase(ip);
    }

    /**
     * Returns the user id from the last {@code DcbUserDTO} argument, falling back to the
     * {@code userId} request parameter when that id is blank or no such argument exists.
     */
    private static String resolveUserId(Object[] args, HttpServletRequest request) {
        String userId = null;
        for (Object arg : args) {
            if (arg instanceof DcbUserDTO) {
                userId = ((DcbUserDTO) arg).getUserId();
            }
        }
        if (StringUtils.isBlank(userId)) {
            userId = request.getParameter("userId");
        }
        return userId;
    }
}