spring boot中 用拦截器和注解的方式实现防刷策略
通过防刷策略实现对controller层相关功能的访问限制,比如限制同一ip地址在某一时段访问次数,限制给同一个手机号频繁发送短信
策略有很多条,其中一条策略内容如下:
{"id":1,
"strategy_topic":"sendsmsphone",
"strategy_type":"smsphonemin",
"max_time":3,
"exseconds":60
}
意思是发送短信,一个手机号60秒内可以最多发3条.
策略可以存在数据库或者redis缓存中供读取.
- strategy_topic策略主题(供程序做分支)一个主题对应多个类型,当一个主题生效后,其所有的类型都会生效,比如对于发短信策略,对应的类型有一分钟限制和半小时限制等.
- strategy_type策略类型(作为redis中的key的一部分)
- max_time过期时间内的限制次数
- exseconds策略过期时间
基本过程:
对于使用策略的方法(以限制短信为例),先根据strategy_topic到redis中查找有没有针对该手机号的内容,如果没有就增加一条String记录.key为"strategy_type:手机号"如"smsphonemin:188",value为1,过期时间来自exseconds. 如果已经存在这个key,且value没有超过max_time,那么value就赠一并返回true,如果超过max_time,就返回false.
最初我写的方法是在需要防刷的方法里加上这条语句:
if (!redisUtils.antiRush("sendsmsphone", userDto.getPhone())) {
throw new CustomException(CommonCode.ANTIRUSH);
}
但是对代码侵入性比较大, 后续维护不方便.于是改用注解+拦截器的方式,对于需要防刷的方法,直接加上注解,并将strategy_topic写入value即可.如果是短信防刷拦截器会自动从方法的参数里获取phone,如果是ip防刷拦截器会自动获取客户端IP地址.
先要创建spring boot项目, 配置redis服务器, 如果需要的话还有连接mysql数据库, 这里把连接数据库省略了, 策略内容直接存在代码里了
具体代码如下:
redis工具类, 包含对redis的操作:
package com.lin.utils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.BoundListOperations;
import org.springframework.data.redis.core.BoundValueOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
* redis工具类,使用spring boot框架, 需要redis服务器并在配置文件中配置服务器信息
*/
@Component
public class RedisUtils {
@Autowired
RedisTemplate<String, String> redisTemplate;
/**
* 设置String类型的key,value, 无论存在不存在
* @param key
* @param value
* @param EXseconds
* @return
*/
public String setValue(String key, String value, long EXseconds) {
BoundValueOperations<String, String> ops = redisTemplate.boundValueOps(key);
ops.set(value);
ops.expire(EXseconds, TimeUnit.SECONDS);
return key;
}
/**
* 根据key,获取String类型的value值
* @param key
* @return
*/
public String getValue(String key) {
BoundValueOperations<String, String> ops = redisTemplate.boundValueOps(key);
String value = ops.get();
return value;
}
/**
* 根据key,获取List类型的所有内容
* @param key
* @return
*/
public List<String> getAllList(String key) {
BoundListOperations<String, String> stringStringBoundListOperations = redisTemplate.boundListOps(key);
List<String> range = stringStringBoundListOperations.range(0, -1);
return range;
}
/**
* 输入List类型的key,判断contain是否在该List中
* @param key
* @param contain
* @return
*/
public boolean containsInList(String key,String contain){
return getAllList(key).contains(contain);
}
/**
* 将key的值赠一
* @param key
* @return
*/
public Long incrValue(String key) {
BoundValueOperations<String, String> ops = redisTemplate.boundValueOps(key);
Long increment = ops.increment();
return increment;
}
/**
* 判断是否存在key值
* @param key
* @return
*/
public boolean haskey(String key) {
return redisTemplate.hasKey(key);
}
}
防刷工具类:
package com.lin.utils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.List;
@Component
public class AntiRushUtils {
@Autowired
private RedisUtils redisUtils;
/**
* 单个策略: 先判断redis中有没有相应的值, 有: 如果超限就返回false, 没超限就自增一后返回true. 没有: set value 后返回true
* @param strateType 策略类型 如:smsmin, smshalfhour, loginipmin, loginipmin等
* @param strateContent 策略内容 如:手机号, api名+ip地址, api名+userid等, "strateType:strateContent"作为redis中的key
* @param maxTimes 时间段内允许的最高访问次数
* @param EXSeconds 时间段, 单位秒
* @return true: 放行 false: 拦截
*/
private boolean strategySingle(String strateType, String strateContent, Integer maxTimes, Long EXSeconds) {
String key = strateType + ":" + strateContent;
String value = redisUtils.getValue(key);
if (value != null) {
System.out.println("触发策略:"+key);
if (Integer.parseInt(value) >= maxTimes) {
// 超出最大值限制
System.out.println("not pass:");
return false;
}else{
// 计数器加一
System.out.println("pass");
redisUtils.incrValue(key);
return true;
}
}
// 内存尚无记录, 添加计数器, 开始计时
redisUtils.setValue(key,"1",EXSeconds);
return true;
}
/**
* 多个策略组合使用, 循环遍历单个策略方法
* 注意, 策略要按照EXSeconds 从小到大的顺序存入List中, 如: 按分钟计算就比按小时计算先放进list中
* 如果大策略拦截了, 但是小策略可以计数, 如果用户狂刷的话, 有可能出现大策略放行, 但是小策略仍然有计数的情况,最终还是拦截
* @param strategies 包含字符数组的List
* 字符数组4项内容: strateType 策略类型 strateContent 策略内容 maxTimes 时间段内允许的最高访问次数 EXSeconds 时间段, 单位秒
* @return true: 放行 false: 拦截
*/
public boolean strategyLoop(List<String[]> strategies, String strategyContent){
if (strategies == null){
}
for (String[] strategy : strategies) {
String strategyType = strategy[0];
int maxTimes = Integer.parseInt(strategy[1]);
long EXSeconds = Long.parseLong(strategy[2]);
if (!strategySingle(strategyType, strategyContent, maxTimes, EXSeconds)){
return false;
}
}
return true;
}
}
防刷注解类:
package com.lin.annotations;
import java.lang.annotation.*;
/**
* 防刷策略注解 用在需要防刷的controller层方法上
* value值对应数据库中的strategy_topic,必填
* 如果请求方式为GET,只取第一个参数
* 多个策略可以使用多重注解, 要放在AntiRushChecks中使用
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Repeatable(AntiRushChecks.class)
public @interface AntiRushCheck {
boolean requried() default true;
String value();
}
多重注解容器:
package com.lin.annotations;
import java.lang.annotation.*;
/**
* java的多重注解容器(重复注解)
* 多个注解存在该容器中,用法:
* @AntiRushChecks({
@AntiRushCheck("sendsmsphone"),
@AntiRushCheck("sendsmsip")
})
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface AntiRushChecks {
AntiRushCheck[] value();
}
防刷拦截器:
package com.lin.interceptors;
import com.alibaba.fastjson.JSONObject;
import com.lin.annotations.AntiRushCheck;
import com.lin.annotations.AntiRushChecks;
import com.lin.exception.CustomException;
import com.lin.response.CommonCode;
import com.lin.utils.AntiRushUtils;
import com.lin.utils.RedisUtils;
import com.lin.utils.SystemUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.StreamUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.*;
@Component
public class AntiRushInterceptor implements HandlerInterceptor {
@Autowired
RedisUtils redisUtils;
@Autowired
AntiRushUtils antiRushUtils;
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
HandlerMethod handlerMethod = (HandlerMethod) handler;
Method method = handlerMethod.getMethod();
boolean mutilChecks = method.isAnnotationPresent(AntiRushChecks.class);
boolean singleCheck = method.isAnnotationPresent(AntiRushCheck.class);
if (mutilChecks){
// 多重注解
AntiRushChecks annotation = method.getAnnotation(AntiRushChecks.class);
AntiRushCheck[] values = annotation.value();
for (AntiRushCheck antiRushCheck : values) {
oneAntiRushCheck(antiRushCheck, request);
}
}else if(singleCheck){
// 单一注解
oneAntiRushCheck(method.getAnnotation(AntiRushCheck.class), request);
}
return true;
}
private void oneAntiRushCheck(AntiRushCheck antiRushCheck, HttpServletRequest request) {
String strategyTopic = antiRushCheck.value(); // 注解中的value值
String strategyType = null;
if (strategyTopic.endsWith("ip")) {
// 获取客户端ip地址 具体方法略过请自行百度
strategyType = "127.0.0.1" ;
}else {
// 目前只有两类防刷: 1.IP 2,手机号, 以下代码获取方法中的参数phone
Map<String, String[]> parameterMap = request.getParameterMap();
if (parameterMap.size()==0){
// 请求类型是POST
String requestBody = null;
try {
byte[] bytes = StreamUtils.copyToByteArray(request.getInputStream());
requestBody = new String(bytes, request.getCharacterEncoding());
} catch (Exception e) {
e.printStackTrace();
}
// 获取请求参数中的phone值
strategyType = JSONObject.parseObject(requestBody).get("phone").toString();
}
else{
// parameterMap不为空,说明请求类型是GET
Collection<String[]> valuesList = parameterMap.values();
for (String[] strs : valuesList) {
// 获取被拦截的方法的参数值, 只有一个参数
strategyType = strs[0];
}
if (strategyType == null) {
// 注解中的value值为空
throw new CustomException(CommonCode.INVALID_PARAM);
}
}
}
boolean send = antiRushUtils.strategyLoop(this.getStrategy().get(strategyTopic), strategyType);
// 只要有一个策略不通过就直接抛出异常
if (!send) {
throw new CustomException(CommonCode.ANTIRUSH);
}
}
/**
* 获取具体防刷策略, 此处存在map中, 生产环境要从缓存或数据库中获取
* @return
*/
private Map<String, List> getStrategy() {
Map<String, List> map = new HashMap<>();
List<String[]> strategyModelsSms = new ArrayList<>();
strategyModelsSms.add(new String[]{"smsphonemin", "2", "60"});
strategyModelsSms.add(new String[]{"smsphonehalfhour", "5", "1800"});
map.put("sendsmsphone", strategyModelsSms);
List<String[]> strategyModelsSmsIP = new ArrayList<>();
strategyModelsSmsIP.add(new String[]{"smsipmin", "2", "60"});
strategyModelsSmsIP.add(new String[]{"smsiphalfhour", "5", "1800"});
map.put("sendsmsip", strategyModelsSms);
return map;
}
}
由于拦截器获取POST请求request输入流后, 输入流指针指向末尾, 被拦截的方法无法获取参数内容, 需要将输入流拷贝并通过重写ServletInputStream的GetInputStream()方法对外提供数据.
我参考的是这个文章:
https://www.cnblogs.com/keeya/p/13634015.html#servletinputstreamcoyoteinputstream-%E8%BE%93%E5%85%A5%E6%B5%81%E6%97%A0%E6%B3%95%E9%87%8D%E5%A4%8D%E8%B0%83%E7%94%A8
涉及两个方法:
NewHttpServletRequestWrapper
package com.lin.entity;
import org.springframework.util.StreamUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
/**
* 自定义 HttpServletRequestWrapper 来包装输入流
*/
public class NewHttpServletRequestWrapper extends HttpServletRequestWrapper {
/**
* 缓存下来的HTTP body
*/
private byte[] body;
public NewHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
super(request);
body = StreamUtils.copyToByteArray(request.getInputStream());
}
/**
* 重新包装输入流
* @return
* @throws IOException
*/
@Override
public ServletInputStream getInputStream() throws IOException {
InputStream bodyStream = new ByteArrayInputStream(body);
return new ServletInputStream() {
@Override
public int read() throws IOException {
return bodyStream.read();
}
/**
* 下面的方法一般情况下不会被使用,如果你引入了一些需要使用ServletInputStream的外部组件,可以重点关注一下。
* @return
*/
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return true;
}
@Override
public void setReadListener(ReadListener readListener) {
}
};
}
@Override
public BufferedReader getReader() throws IOException {
InputStream bodyStream = new ByteArrayInputStream(body);
return new BufferedReader(new InputStreamReader(getInputStream()));
}
}
NewDispatcherServlet
package com.lin.entity;
import org.springframework.web.servlet.DispatcherServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* 自定义 DispatcherServlet 来分派 NewHttpServletRequestWrapper
*/
public class NewDispatcherServlet extends DispatcherServlet{
/**
* 包装成我们自定义的request
* @param request
* @param response
* @throws Exception
*/
@Override
protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
super.doDispatch(new NewHttpServletRequestWrapper(request), response);
}
}
Configuration类
package com.lin.interceptors;
import com.lin.entity.NewDispatcherServlet;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
@Configuration
public class InterceptorConfig implements WebMvcConfigurer {
@Autowired
AntiRushInterceptor antiRushInterceptor;
@Override
public void addInterceptors(InterceptorRegistry registry) {
// 声明自定义的拦截器对象和要拦截的请求路径
registry.addInterceptor(antiRushInterceptor) .addPathPatterns("/api/**") ;
}
@Bean
@Qualifier(DispatcherServletAutoConfiguration.DEFAULT_DISPATCHER_SERVLET_BEAN_NAME)
public DispatcherServlet dispatcherServlet() {
return new NewDispatcherServlet();
}
}
还需要一些常规的类:
- 自定义异常类CustomException
- 自定义返回值类ResponseResult
- CommonCode类
- -ResultCode接口
- UserDto实体类用于传参,其中要有phone属性传递要发短信的手机号
上述代码都属于常规代码,在此略过
接下来,在项目的controller类中需要使用防刷策略的方法上加上注解@AntiRushCheck(“策略主题”)就可以了.
如果想对一个方法实现多个策略,用@AntiRushChecks容器包含多个@AntiRushCheck即可,具体格式如下:
@AntiRushChecks({
@AntiRushCheck("sendsmsphone"),
@AntiRushCheck("sendsmsip")
})
头一次写这么长的, 希望多多提意见!