Project repository: redis-limit
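Sometimes an application needs access control on certain endpoints to keep the service stable. The most common approach is to limit by IP address; another is to limit by username. You can also cap an endpoint outright so that it can only be called a fixed number of times within a time window. Which one you pick depends on your own requirements. The approach here is to put a custom annotation on each method that needs limiting, intercept that method with AOP, count calls with the Redis INCR command, and throw a custom exception once the count exceeds the limit.
The code is commented, so it is worth reading through carefully.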
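The counting idea described above can be sketched without Redis at all. The following is a minimal in-memory stand-in, not the implementation used later in this article: `FixedWindowCounter`, its `hit` and `tryAcquire` methods, and the injected clock are hypothetical names introduced only for illustration. It mimics "increment a counter, start the expiry on the first hit, reject once the counter passes the limit".

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.LongSupplier;

/** In-memory stand-in for the Redis INCR + EXPIRE pattern (illustration only). */
final class FixedWindowCounter {

    private static final class Window {
        long count;
        long expiresAtMillis;
    }

    private final Map<String, Window> windows = new HashMap<>();
    // The clock is injected so the behaviour can be exercised without sleeping.
    private final LongSupplier clock;

    FixedWindowCounter(LongSupplier clock) {
        this.clock = clock;
    }

    /** Increments the counter for the key and returns the new value; the window starts on the first hit. */
    synchronized long hit(String key, int periodSeconds) {
        long now = clock.getAsLong();
        Window w = windows.get(key);
        if (w == null || now >= w.expiresAtMillis) {
            // Either the first request, or the previous window has expired.
            w = new Window();
            w.expiresAtMillis = now + periodSeconds * 1000L;
            windows.put(key, w);
        }
        return ++w.count;
    }

    /** Returns true while the caller is still within the allowed number of requests. */
    boolean tryAcquire(String key, int limit, int periodSeconds) {
        return hit(key, periodSeconds) <= limit;
    }
}

final class FixedWindowCounterDemo {
    public static void main(String[] args) {
        long[] now = {0L};
        FixedWindowCounter counter = new FixedWindowCounter(() -> now[0]);
        for (int i = 1; i <= 6; i++) {
            System.out.println("request " + i + " allowed=" + counter.tryAcquire("demo-key", 5, 60));
        }
        now[0] = 60_000L; // the window has expired, so counting starts over
        System.out.println("after expiry allowed=" + counter.tryAcquire("demo-key", 5, 60));
    }
}
```

With a limit of 5 per 60 seconds, the first five calls are allowed, the sixth is rejected, and a call made after the window expires is allowed again. A real deployment needs the shared, atomic counter that Redis provides, which is what the rest of the article builds.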
Preparation
1. Dependencies
This is a Spring Boot project. Only the dependencies relevant to this article are listed; add anything else you need and adjust versions as appropriate. Note that the list does not include Jedis itself: spring-boot-starter-data-redis pulls in Lettuce by default, so add a Jedis dependency if you want the spring.redis.jedis.pool.* settings in step 7 to take effect.
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <optional>true</optional>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
    <exclusions>
        <exclusion>
            <groupId>org.junit.vintage</groupId>
            <artifactId>junit-vintage-engine</artifactId>
        </exclusion>
    </exclusions>
</dependency>
<dependency>
    <groupId>javax.validation</groupId>
    <artifactId>validation-api</artifactId>
    <version>1.1.0.Final</version>
</dependency>
<dependency>
    <groupId>org.hibernate</groupId>
    <artifactId>hibernate-validator</artifactId>
    <version>5.4.1.Final</version>
</dependency>
<dependency>
    <groupId>cn.hutool</groupId>
    <artifactId>hutool-all</artifactId>
    <version>5.3.5</version>
</dependency>
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-lang3</artifactId>
</dependency>
<!-- google guava -->
<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>27.0-jre</version>
</dependency>
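2. Custom annotation
package com.lichong.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {
    /**
     * Resource name, used to describe what the endpoint does
     */
    String name() default "";
    /**
     * Resource key
     */
    String key() default "";
    /**
     * Key prefix
     */
    String prefix() default "";
    /**
     * Time window, in seconds
     */
    int period();
    /**
     * Maximum number of requests allowed within the window
     */
    int count();
    /**
     * Limit type
     */
    LimitType limitType() default LimitType.IP;
}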
package com.lichong.annotation;

public enum LimitType {
    /**
     * Limit by username
     */
    CUSTOMER,
    /**
     * Limit by IP address
     */
    IP
}
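Because the annotation is retained at runtime, its values can be read back through reflection, which is what the aspect later relies on. The sketch below illustrates this in isolation. It redeclares a trimmed-down copy of the annotation under the hypothetical names `DemoLimit` and `DemoLimitType` so that it compiles on its own; `SampleController` and `describe` are likewise made up for the example.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

final class LimitAnnotationDemo {

    enum DemoLimitType { CUSTOMER, IP }

    /** Trimmed-down copy of the Limit annotation, redeclared so this example is self-contained. */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @interface DemoLimit {
        String key() default "";
        int period();
        int count();
        DemoLimitType limitType() default DemoLimitType.IP;
    }

    static class SampleController {
        @DemoLimit(key = "login", period = 60, count = 5)
        public String login() {
            return "ok";
        }
    }

    /** Formats the limit declared on a method, or reports that none is present. */
    static String describe(Method method) {
        DemoLimit limit = method.getAnnotation(DemoLimit.class);
        if (limit == null) {
            return method.getName() + ": no limit";
        }
        return method.getName() + ": " + limit.count() + " requests per "
                + limit.period() + "s, type=" + limit.limitType();
    }

    /** Convenience wrapper that looks up the sample method without a checked exception. */
    static String describeSampleLogin() {
        try {
            return describe(SampleController.class.getMethod("login"));
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(describeSampleLogin()); // login: 5 requests per 60s, type=IP
    }
}
```

Note that `limitType` falls back to `IP` when it is not specified, which matches the default declared on the annotation in this article.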
3. Redis serialization configuration
package com.lichong.config;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

/**
 * @author lichong
 */
@Configuration
public class RedisConfigure {
    @Bean
    @ConditionalOnClass(RedisOperations.class)
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(factory);
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper mapper = new ObjectMapper();
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        mapper.activateDefaultTyping(mapper.getPolymorphicTypeValidator(), ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(mapper);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // Keys are serialized as Strings
        template.setKeySerializer(stringRedisSerializer);
        // Hash keys are serialized as Strings too
        template.setHashKeySerializer(stringRedisSerializer);
        // Values are serialized with Jackson
        template.setValueSerializer(jackson2JsonRedisSerializer);
        // Hash values are serialized with Jackson too
        template.setHashValueSerializer(jackson2JsonRedisSerializer);
        template.afterPropertiesSet();
        return template;
    }
}
4. Global exception handling
package com.lichong.exception;

/**
 * Custom rate-limit exception
 *
 * @author lichong
 */
public class LimitAccessException extends RuntimeException {
    private static final long serialVersionUID = -3608667856397125671L;

    public LimitAccessException(String message) {
        super(message);
    }
}
package com.lichong.handler;

import com.lichong.common.ResponseVO;
import com.lichong.exception.LimitAccessException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

/**
 * @author lichong
 */
@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {
    // Fallback for any exception that has no more specific handler
    @ExceptionHandler(value = Exception.class)
    public ResponseVO handleException(Exception e) {
        log.error("Internal server error", e);
        return new ResponseVO().code(HttpStatus.INTERNAL_SERVER_ERROR).message("Internal server error");
    }

    // Rate-limit violations are reported as 429 Too Many Requests
    @ExceptionHandler(value = LimitAccessException.class)
    public ResponseVO handleLimitAccessException(LimitAccessException e) {
        log.debug("LimitAccessException", e);
        return new ResponseVO().code(HttpStatus.TOO_MANY_REQUESTS).message(e.getMessage());
    }
}
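The handler above returns a `ResponseVO` from `com.lichong.common`, whose source is not shown in this article. As a rough guess at its shape, based only on the fluent calls used here (`code`, `message`, `success`, `fail`), it could look something like the sketch below. Everything in it is an assumption: the class is named `SketchResponseVO` to make clear it is hypothetical, the status is taken as a plain `int` rather than Spring's `HttpStatus` so the example has no dependencies, and the status codes chosen for `success()` and `fail()` are guesses.

```java
import java.util.LinkedHashMap;

/** Hypothetical fluent response body; the real ResponseVO is not shown in the article. */
class SketchResponseVO extends LinkedHashMap<String, Object> {
    private static final long serialVersionUID = 1L;

    /** Stores the numeric status code (the real class appears to accept an HttpStatus). */
    SketchResponseVO code(int status) {
        put("code", status);
        return this;
    }

    /** Stores a human-readable message. */
    SketchResponseVO message(String message) {
        put("message", message);
        return this;
    }

    /** Assumed shorthand for a 200 response. */
    SketchResponseVO success() {
        return code(200);
    }

    /** Assumed shorthand for a 500 response. */
    SketchResponseVO fail() {
        return code(500);
    }
}

final class SketchResponseVODemo {
    public static void main(String[] args) {
        SketchResponseVO body = new SketchResponseVO().code(429).message("Too many requests");
        System.out.println(body);
    }
}
```

Extending a map is one simple way to get a JSON object out of Jackson without declaring fields; a plain class with getters would work just as well.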
5. Utility classes (1. resolving the IP address, 2. obtaining the HttpServletRequest)
package com.lichong.util;

import javax.servlet.http.HttpServletRequest;

import static cn.hutool.core.net.NetUtil.isInnerIP;

/**
 * @author lichong
 */
public class IPUtil {
    private static final String UNKNOWN = "unknown";

    protected IPUtil() {
    }

    /**
     * Resolves the client IP address.
     * Behind a reverse proxy such as Nginx, request.getRemoteAddr() no longer returns the client IP.
     * With several proxy levels, X-Forwarded-For holds a comma-separated list of addresses rather than
     * a single value; the first valid entry that is not "unknown" is taken as the real client IP.
     *
     * IP addresses can be spoofed, so the result is never 100% reliable.
     */
    public static String getIpAddr(HttpServletRequest request) {
        // Try the proxy headers first, then fall back to the remote address
        String clientIp = request.getHeader("x-forwarded-for");
        if (clientIp == null || clientIp.length() == 0 || UNKNOWN.equalsIgnoreCase(clientIp)) {
            clientIp = request.getHeader("Proxy-Client-IP");
        }
        if (clientIp == null || clientIp.length() == 0 || UNKNOWN.equalsIgnoreCase(clientIp)) {
            clientIp = request.getHeader("WL-Proxy-Client-IP");
        }
        if (clientIp == null || clientIp.length() == 0 || UNKNOWN.equalsIgnoreCase(clientIp)) {
            clientIp = request.getRemoteAddr();
        }
        // When several IPs are present, pick the first public one
        String sIP = null;
        if (clientIp != null && !clientIp.contains(UNKNOWN) && clientIp.indexOf(",") > 0) {
            String[] ipsz = clientIp.split(",");
            for (String anIpsz : ipsz) {
                if (!isInnerIP(anIpsz.trim())) {
                    sIP = anIpsz.trim();
                    break;
                }
            }
            // If every IP is internal, take the first one
            if (null == sIP) {
                sIP = ipsz[0].trim();
            }
            clientIp = sIP;
        }
        if (clientIp != null && clientIp.contains(UNKNOWN)) {
            clientIp = clientIp.replaceAll("unknown,", "");
            clientIp = clientIp.trim();
        }
        if ("".equals(clientIp) || null == clientIp) {
            clientIp = "127.0.0.1";
        }
        return clientIp;
    }
}
package com.lichong.util;

import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.Objects;

/**
 * Outside a controller, the current HttpServletRequest can be obtained through RequestContextHolder.
 *
 * @author lichong
 */
public class HttpContextUtil {
    private HttpContextUtil() {
    }

    public static HttpServletRequest getHttpServletRequest() {
        return ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
    }
}
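The header handling in `getIpAddr` is easier to follow when separated from the servlet API. The sketch below is a simplified illustration of the same idea (skip `unknown` entries, prefer the first public address, otherwise fall back), not a drop-in replacement. `ForwardedForDemo`, `resolveClientIp`, and `isPrivateIpv4` are hypothetical names, and `isPrivateIpv4` is only a rough stand-in for Hutool's `NetUtil.isInnerIP` that handles dotted IPv4 strings.

```java
final class ForwardedForDemo {

    /** Rough private-range check for dotted IPv4 strings (simplified stand-in for NetUtil.isInnerIP). */
    static boolean isPrivateIpv4(String ip) {
        String[] parts = ip.split("\\.");
        if (parts.length != 4) {
            return false;
        }
        int a;
        int b;
        try {
            a = Integer.parseInt(parts[0]);
            b = Integer.parseInt(parts[1]);
        } catch (NumberFormatException e) {
            return false;
        }
        return a == 10 || a == 127 || (a == 172 && b >= 16 && b <= 31) || (a == 192 && b == 168);
    }

    /**
     * Picks the first public address from an X-Forwarded-For value. If all entries are private,
     * the first usable entry is returned; if there is none at all, remoteAddr is returned.
     */
    static String resolveClientIp(String forwardedFor, String remoteAddr) {
        if (forwardedFor == null || forwardedFor.trim().isEmpty()) {
            return remoteAddr;
        }
        String firstCandidate = null;
        for (String raw : forwardedFor.split(",")) {
            String ip = raw.trim();
            if (ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
                continue; // proxies sometimes insert "unknown" placeholders
            }
            if (firstCandidate == null) {
                firstCandidate = ip;
            }
            if (!isPrivateIpv4(ip)) {
                return ip;
            }
        }
        return firstCandidate != null ? firstCandidate : remoteAddr;
    }

    public static void main(String[] args) {
        System.out.println(resolveClientIp("unknown, 10.0.0.5, 203.0.113.7", "127.0.0.1")); // 203.0.113.7
        System.out.println(resolveClientIp("10.0.0.5, 192.168.1.2", "127.0.0.1"));          // 10.0.0.5
        System.out.println(resolveClientIp(null, "127.0.0.1"));                             // 127.0.0.1
    }
}
```

As the comment in `IPUtil` already warns, these headers are supplied by the client or by intermediaries, so an IP-based limit should only trust them when the application sits behind a proxy you control.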
6. Reflection support, used to resolve the method behind a join point.
package com.lichong.aspect;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.reflect.MethodSignature;

import java.lang.reflect.Method;

/**
 * Resolves the method of the current join point through reflection.
 *
 * @author lichong
 */
public abstract class AspectSupport {
    Method resolveMethod(ProceedingJoinPoint point) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Class<?> targetClass = point.getTarget().getClass();
        Method method = getDeclaredMethod(targetClass, signature.getName(),
                signature.getMethod().getParameterTypes());
        if (method == null) {
            throw new IllegalStateException("Unable to resolve target method: " + signature.getMethod().getName());
        }
        return method;
    }

    /**
     * Looks up a declared method by name and parameter types, walking up the
     * superclass chain when the class itself does not declare it.
     *
     * @param clazz          class to start searching from
     * @param name           method name
     * @param parameterTypes parameter types of the method
     * @return the matching Method, or null if none is found
     */
    private Method getDeclaredMethod(Class<?> clazz, String name, Class<?>... parameterTypes) {
        try {
            return clazz.getDeclaredMethod(name, parameterTypes);
        } catch (NoSuchMethodException e) {
            Class<?> superClass = clazz.getSuperclass();
            if (superClass != null) {
                return getDeclaredMethod(superClass, name, parameterTypes);
            }
        }
        return null;
    }
}
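The superclass walk in `getDeclaredMethod` matters because `Class.getDeclaredMethod` only sees methods declared on that exact class, not inherited ones. The small sketch below shows the difference using made-up classes (`BaseService`, `ChildService`) and an iterative variant of the same lookup under the hypothetical name `findDeclaredMethod`.

```java
import java.lang.reflect.Method;

final class DeclaredMethodDemo {

    static class BaseService {
        protected String ping(String name) {
            return "pong " + name;
        }
    }

    /** Inherits ping(String) but does not declare it. */
    static class ChildService extends BaseService {
    }

    /** Searches the class and then its superclasses for a declared method; returns null if none matches. */
    static Method findDeclaredMethod(Class<?> clazz, String name, Class<?>... parameterTypes) {
        for (Class<?> c = clazz; c != null; c = c.getSuperclass()) {
            try {
                return c.getDeclaredMethod(name, parameterTypes);
            } catch (NoSuchMethodException ignored) {
                // Not declared on this class; continue with the parent.
            }
        }
        return null;
    }

    /** Returns true only if the method is declared on exactly this class. */
    static boolean declaredDirectly(Class<?> clazz, String name, Class<?>... parameterTypes) {
        try {
            clazz.getDeclaredMethod(name, parameterTypes);
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // The child class alone does not declare ping, so a single lookup fails.
        System.out.println(declaredDirectly(ChildService.class, "ping", String.class)); // false
        // Walking up the hierarchy finds it on the parent.
        Method m = findDeclaredMethod(ChildService.class, "ping", String.class);
        System.out.println(m.getDeclaringClass().getSimpleName()); // BaseService
    }
}
```

This is why a protected or inherited handler method can still be resolved from the join point's target class.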
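7. Configuration file
# Redis server port
spring.redis.port=6379
# Redis server host
spring.redis.host=127.0.0.1
# Redis database index (defaults to 0)
spring.redis.database=0
# Redis server password (empty by default)
spring.redis.password=
# Maximum number of connections in the pool (a negative value means no limit)
spring.redis.jedis.pool.max-active=8
# Maximum time to block waiting for a connection (a negative value means no limit)
spring.redis.jedis.pool.max-wait=-1ms
# Maximum number of idle connections in the pool
spring.redis.jedis.pool.max-idle=8
# Minimum number of idle connections in the pool
spring.redis.jedis.pool.min-idle=0
# Connection timeout (milliseconds)
spring.redis.timeout=5000ms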
With the preparation done, write the aspect and the endpoint
package com.lichong.aspect;

import com.google.common.collect.ImmutableList;
import com.lichong.annotation.Limit;
import com.lichong.annotation.LimitType;
import com.lichong.exception.LimitAccessException;
import com.lichong.util.HttpContextUtil;
import com.lichong.util.IPUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;

/**
 * Endpoint rate limiting
 *
 * @author lichong
 */
@Slf4j
@Aspect
@Component
public class LimitAspect extends AspectSupport {
    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Pointcut("@annotation(com.lichong.annotation.Limit)")
    public void pointcut() {
    }

    @Around("pointcut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        // Obtain the current HttpServletRequest
        HttpServletRequest request = HttpContextUtil.getHttpServletRequest();
        Method method = resolveMethod(point);
        // Read the annotation from the method
        Limit limitAnnotation = method.getAnnotation(Limit.class);
        // CUSTOMER or IP
        LimitType limitType = limitAnnotation.limitType();
        // Human-readable description; only used in the log message below
        String name = limitAnnotation.name();
        String key;
        // Resolve the client IP
        String ip = IPUtil.getIpAddr(request);
        int limitPeriod = limitAnnotation.period();
        int limitCount = limitAnnotation.count();
        switch (limitType) {
            case IP:
                // One counter per endpoint key and client IP
                key = limitAnnotation.key() + ip;
                break;
            case CUSTOMER:
                // One counter per endpoint key, as given in the annotation
                key = limitAnnotation.key();
                break;
            default:
                key = StringUtils.upperCase(method.getName());
        }
        // The Redis key uniquely identifies the counter
        ImmutableList<String> keys = ImmutableList.of(StringUtils.join(limitAnnotation.prefix() + "_", key));
        String luaScript = buildLuaScript();
        RedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class);
        Number count = redisTemplate.execute(redisScript, keys, limitCount, limitPeriod);
        log.info("IP {} made request number {} to the endpoint with key {}, described as [{}]", ip, count, keys, name);
        if (count != null && count.intValue() <= limitCount) {
            return point.proceed();
        } else {
            throw new LimitAccessException("Too many requests, please try again later");
        }
    }

    /**
     * Rate-limiting script; see the Redis documentation at http://doc.redisfans.com/script/eval.html
     * If the counter has already passed the threshold, its current value is returned unchanged.
     * Otherwise the counter is incremented, the expiry is set on the first hit, and the new value is returned.
     *
     * @return the Lua script
     */
    private String buildLuaScript() {
        return "local c" +
                "\n c = redis.call('get',KEYS[1])" +
                "\n if c and tonumber(c) > tonumber(ARGV[1]) then" +
                "\n return c;" +
                "\n end" +
                "\n c = redis.call('incr',KEYS[1])" +
                "\n if tonumber(c) == 1 then" +
                "\n redis.call('expire',KEYS[1],ARGV[2])" +
                "\n end" +
                "\n return c;";
    }
}
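To make the key construction in the aspect concrete, the following sketch reproduces just that step in plain Java. `LimitKeyDemo` and `buildKey` are hypothetical names, and the nested `LimitType` is a local copy so the example compiles on its own; the logic follows the `switch` and the `prefix + "_"` concatenation in `LimitAspect`.

```java
final class LimitKeyDemo {

    /** Local copy of the article's LimitType enum, so the example is self-contained. */
    enum LimitType { CUSTOMER, IP }

    /**
     * Builds the Redis key the same way the aspect does: the IP type appends the
     * client IP to the annotation key, the CUSTOMER type uses the annotation key as is.
     */
    static String buildKey(String prefix, String key, LimitType type, String ip) {
        String suffix = (type == LimitType.IP) ? key + ip : key;
        return prefix + "_" + suffix;
    }

    public static void main(String[] args) {
        System.out.println(buildKey("limit", "login", LimitType.IP, "192.168.1.10"));       // limit_login192.168.1.10
        System.out.println(buildKey("limit", "login", LimitType.CUSTOMER, "192.168.1.10")); // limit_login
    }
}
```

One consequence worth noticing: with `CUSTOMER`, every caller shares the same counter unless the annotation key itself contains something user-specific, so the code as written limits the endpoint as a whole rather than each user separately.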
package com.lichong.controller;

import com.lichong.annotation.Limit;
import com.lichong.common.ResponseVO;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
@Slf4j
public class Test {
    /**
     * The login endpoint can be called at most 5 times within 60 seconds from a single IP address
     */
    @RequestMapping("login")
    @Limit(key = "login", period = 60, count = 5, name = "login endpoint", prefix = "limit")
    public ResponseVO login(@RequestParam("username") String username, @RequestParam("password") String password) {
        if ("admin".equals(username) && "12345".equals(password)) {
            return new ResponseVO().success().message("Login succeeded");
        } else {
            return new ResponseVO().fail().message("Login failed");
        }
    }
}
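Testing
The first request goes through normally.
On the sixth request the limit is exceeded: the custom exception is thrown, caught by the global exception handler, and the request fails.
That completes the endpoint access limit built with Redis and AOP.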