spring boot中 用拦截器和注解的方式实现防刷策略

spring boot中 用拦截器和注解的方式实现防刷策略

通过防刷策略实现对controller层相关功能的访问限制,比如限制同一ip地址在某一时段访问次数,限制给同一个手机号频繁发送短信

策略有很多条,其中一条策略内容如下:

{"id":1,
"strategy_topic":"sendsmsphone",
"strategy_type":"smsphonemin",
"max_time":3,
"exseconds":60
}

意思是发送短信,一个手机号60秒内可以最多发3条.
策略可以存在数据库或者redis缓存中供读取.

  • strategy_topic策略主题(供程序做分支)一个主题对应多个类型,当一个主题生效后,其所有的类型都会生效,比如对于发短信策略,对应的类型有一分钟限制和半小时限制等.
  • strategy_type策略类型(作为redis中的key的一部分)
  • max_time过期时间内的限制次数
  • exseconds策略过期时间

基本过程:
对于使用策略的方法(以限制短信为例),先根据strategy_topic到redis中查找有没有针对该手机号的内容,如果没有就增加一条String记录.key为"strategy_type:手机号"如"smsphonemin:188",value为1,过期时间来自exseconds. 如果已经存在这个key,且value没有超过max_time,那么value就赠一并返回true,如果超过max_time,就返回false.

最初我写的方法是在需要防刷的方法里加上这条语句:

        if (!redisUtils.antiRush("sendsmsphone", userDto.getPhone())) {
            throw new CustomException(CommonCode.ANTIRUSH);
        }

但是对代码侵入性比较大, 后续维护不方便.于是改用注解+拦截器的方式,对于需要防刷的方法,直接加上注解,并将strategy_topic写入value即可.如果是短信防刷拦截器会自动从方法的参数里获取phone,如果是ip防刷拦截器会自动获取客户端IP地址.

先要创建spring boot项目, 配置redis服务器, 如果需要的话还有连接mysql数据库, 这里把连接数据库省略了, 策略内容直接存在代码里了
具体代码如下:

redis工具类, 包含对redis的操作:

package com.lin.utils;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.BoundListOperations;
import org.springframework.data.redis.core.BoundValueOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.TimeUnit;

/**
 * redis工具类,使用spring boot框架, 需要redis服务器并在配置文件中配置服务器信息
 */
@Component
public class RedisUtils {
    
    @Autowired
    RedisTemplate<String, String> redisTemplate;

    /**
     * 设置String类型的key,value, 无论存在不存在
     * @param key
     * @param value
     * @param EXseconds
     * @return
     */
    public String setValue(String key, String value, long EXseconds) {
        BoundValueOperations<String, String> ops = redisTemplate.boundValueOps(key);
        ops.set(value);
        ops.expire(EXseconds, TimeUnit.SECONDS);
        return key;
    }

    /**
     * 根据key,获取String类型的value值
     * @param key
     * @return
     */
    public String getValue(String key) {
        BoundValueOperations<String, String> ops = redisTemplate.boundValueOps(key);
        String value = ops.get();
        return value;
    }

    /**
     * 根据key,获取List类型的所有内容
     * @param key
     * @return
     */
    public List<String> getAllList(String key) {
        BoundListOperations<String, String> stringStringBoundListOperations = redisTemplate.boundListOps(key);
        List<String> range = stringStringBoundListOperations.range(0, -1);
        return range;
    }
    
	/**
     * 输入List类型的key,判断contain是否在该List中
     * @param key
     * @param contain
     * @return
     */
    public boolean containsInList(String key,String contain){
        return getAllList(key).contains(contain);
    }
    
    /**
     * 将key的值赠一
     * @param key
     * @return
     */
    public Long incrValue(String key) {
        BoundValueOperations<String, String> ops = redisTemplate.boundValueOps(key);
        Long increment = ops.increment();
        return increment;
    }

    /**
     * 判断是否存在key值
     * @param key
     * @return
     */
    public boolean haskey(String key) {
        return redisTemplate.hasKey(key);
    }
}

防刷工具类:

package com.lin.utils;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.List;

@Component
public class AntiRushUtils {

    @Autowired
    private RedisUtils redisUtils;

    /**
     * 单个策略: 先判断redis中有没有相应的值, 有: 如果超限就返回false, 没超限就自增一后返回true. 没有: set value 后返回true
     * @param strateType  策略类型 如:smsmin, smshalfhour, loginipmin, loginipmin等
     * @param strateContent 策略内容 如:手机号, api名+ip地址, api名+userid等, "strateType:strateContent"作为redis中的key
     * @param maxTimes 时间段内允许的最高访问次数
     * @param EXSeconds 时间段, 单位秒
     * @return true: 放行 false: 拦截
     */
    private boolean strategySingle(String strateType, String strateContent, Integer maxTimes, Long EXSeconds) {
        String key = strateType + ":" + strateContent;
        String value = redisUtils.getValue(key);
        if (value != null) {
            System.out.println("触发策略:"+key);
            if (Integer.parseInt(value) >= maxTimes) {
                // 超出最大值限制
                System.out.println("not pass:");
                return false;
            }else{
                // 计数器加一
                System.out.println("pass");
                redisUtils.incrValue(key);
                return true;
            }
        }
        // 内存尚无记录, 添加计数器, 开始计时
        redisUtils.setValue(key,"1",EXSeconds);
        return true;
    }

    /**
     * 多个策略组合使用, 循环遍历单个策略方法
     * 注意, 策略要按照EXSeconds 从小到大的顺序存入List中, 如: 按分钟计算就比按小时计算先放进list中
     * 如果大策略拦截了, 但是小策略可以计数, 如果用户狂刷的话, 有可能出现大策略放行, 但是小策略仍然有计数的情况,最终还是拦截
     * @param strategies 包含字符数组的List
     * 字符数组4项内容: strateType  策略类型   strateContent 策略内容   maxTimes 时间段内允许的最高访问次数   EXSeconds 时间段, 单位秒
     * @return true: 放行 false: 拦截
     */
    public boolean strategyLoop(List<String[]> strategies, String strategyContent){
        if (strategies == null){
        }
        for (String[] strategy : strategies) {
            String strategyType = strategy[0];
            int maxTimes = Integer.parseInt(strategy[1]);
            long EXSeconds = Long.parseLong(strategy[2]);
            if (!strategySingle(strategyType, strategyContent, maxTimes, EXSeconds)){
                return false;
            }
        }
        return true;
    }
}

防刷注解类:

package com.lin.annotations;

import java.lang.annotation.*;

/**
 * 防刷策略注解 用在需要防刷的controller层方法上
 * value值对应数据库中的strategy_topic,必填
 * 如果请求方式为GET,只取第一个参数
 * 多个策略可以使用多重注解, 要放在AntiRushChecks中使用
 */

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Repeatable(AntiRushChecks.class)
public @interface AntiRushCheck {

    boolean requried() default true;
    String value();
}

多重注解容器:

package com.lin.annotations;

import java.lang.annotation.*;

/**
 * java的多重注解容器(重复注解)
 * 多个注解存在该容器中,用法:
 *   @AntiRushChecks({
         @AntiRushCheck("sendsmsphone"),
         @AntiRushCheck("sendsmsip")
     })
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface AntiRushChecks {
    AntiRushCheck[] value();
}

防刷拦截器:

package com.lin.interceptors;

import com.alibaba.fastjson.JSONObject;
import com.lin.annotations.AntiRushCheck;
import com.lin.annotations.AntiRushChecks;
import com.lin.exception.CustomException;
import com.lin.response.CommonCode;
import com.lin.utils.AntiRushUtils;
import com.lin.utils.RedisUtils;
import com.lin.utils.SystemUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.StreamUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.*;

@Component
public class AntiRushInterceptor implements HandlerInterceptor {

    @Autowired
    RedisUtils redisUtils;
    @Autowired
    AntiRushUtils antiRushUtils;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        HandlerMethod handlerMethod = (HandlerMethod) handler;
        Method method = handlerMethod.getMethod();
        boolean mutilChecks = method.isAnnotationPresent(AntiRushChecks.class);
        boolean singleCheck = method.isAnnotationPresent(AntiRushCheck.class);
        if (mutilChecks){
            // 多重注解
            AntiRushChecks annotation = method.getAnnotation(AntiRushChecks.class);
            AntiRushCheck[] values = annotation.value();
            for (AntiRushCheck antiRushCheck : values) {
                oneAntiRushCheck(antiRushCheck, request);
            }
        }else if(singleCheck){
            // 单一注解
            oneAntiRushCheck(method.getAnnotation(AntiRushCheck.class), request);
        }
        return true;
    }

    private void oneAntiRushCheck(AntiRushCheck antiRushCheck, HttpServletRequest request) {

        String strategyTopic = antiRushCheck.value();  // 注解中的value值
        String strategyType = null;
        if (strategyTopic.endsWith("ip")) {
            // 获取客户端ip地址 具体方法略过请自行百度
            strategyType = "127.0.0.1" ;
        }else {
            // 目前只有两类防刷: 1.IP 2,手机号, 以下代码获取方法中的参数phone
            Map<String, String[]> parameterMap = request.getParameterMap();
            if (parameterMap.size()==0){
                // 请求类型是POST
                String requestBody = null;
                try {
                    byte[] bytes = StreamUtils.copyToByteArray(request.getInputStream());
                    requestBody = new String(bytes, request.getCharacterEncoding());
                } catch (Exception e) {
                    e.printStackTrace();
                }
                // 获取请求参数中的phone值
                strategyType = JSONObject.parseObject(requestBody).get("phone").toString();
                }
            else{
                // parameterMap不为空,说明请求类型是GET
                Collection<String[]> valuesList = parameterMap.values();
                for (String[] strs : valuesList) {
                    // 获取被拦截的方法的参数值, 只有一个参数
                    strategyType = strs[0];
                }
                if (strategyType == null) {
                    // 注解中的value值为空
                    throw new CustomException(CommonCode.INVALID_PARAM);
                }
            }
        }
        boolean send = antiRushUtils.strategyLoop(this.getStrategy().get(strategyTopic), strategyType);
        // 只要有一个策略不通过就直接抛出异常
        if (!send) {
            throw new CustomException(CommonCode.ANTIRUSH);
        }
    }

    /**
     * 获取具体防刷策略, 此处存在map中, 生产环境要从缓存或数据库中获取
     * @return
     */
    private Map<String, List> getStrategy() {
        Map<String, List> map = new HashMap<>();
        List<String[]> strategyModelsSms = new ArrayList<>();
        strategyModelsSms.add(new String[]{"smsphonemin", "2", "60"});
        strategyModelsSms.add(new String[]{"smsphonehalfhour", "5", "1800"});
        map.put("sendsmsphone", strategyModelsSms);
        List<String[]> strategyModelsSmsIP = new ArrayList<>();
        strategyModelsSmsIP.add(new String[]{"smsipmin", "2", "60"});
        strategyModelsSmsIP.add(new String[]{"smsiphalfhour", "5", "1800"});
        map.put("sendsmsip", strategyModelsSms);
        return map;
    }
}

由于拦截器获取POST请求request输入流后, 输入流指针指向末尾, 被拦截的方法无法获取参数内容, 需要将输入流拷贝并通过重写ServletInputStream的GetInputStream()方法对外提供数据.
我参考的是这个文章:
https://www.cnblogs.com/keeya/p/13634015.html#servletinputstreamcoyoteinputstream-%E8%BE%93%E5%85%A5%E6%B5%81%E6%97%A0%E6%B3%95%E9%87%8D%E5%A4%8D%E8%B0%83%E7%94%A8

涉及两个方法:

NewHttpServletRequestWrapper

package com.lin.entity;

import org.springframework.util.StreamUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;

/**
 * 自定义 HttpServletRequestWrapper 来包装输入流
 */
public class NewHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /**
     * 缓存下来的HTTP body
     */
    private byte[] body;

    public NewHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        body = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * 重新包装输入流
     * @return
     * @throws IOException
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        InputStream bodyStream = new ByteArrayInputStream(body);
        return new ServletInputStream() {

            @Override
            public int read() throws IOException {
                return bodyStream.read();
            }

            /**
             * 下面的方法一般情况下不会被使用,如果你引入了一些需要使用ServletInputStream的外部组件,可以重点关注一下。
             * @return
             */
            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {

            }
        };
    }

    @Override
    public BufferedReader getReader() throws IOException {
        InputStream bodyStream = new ByteArrayInputStream(body);
        return new BufferedReader(new InputStreamReader(getInputStream()));
    }
}

NewDispatcherServlet

package com.lin.entity;

import org.springframework.web.servlet.DispatcherServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * 自定义 DispatcherServlet 来分派 NewHttpServletRequestWrapper
 */
public class NewDispatcherServlet extends DispatcherServlet{
    /**
     * 包装成我们自定义的request
     * @param request
     * @param response
     * @throws Exception
     */
    @Override
    protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
        super.doDispatch(new NewHttpServletRequestWrapper(request), response);
    }
}

Configuration类

package com.lin.interceptors;

import com.lin.entity.NewDispatcherServlet;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class InterceptorConfig implements WebMvcConfigurer {

    @Autowired
    AntiRushInterceptor antiRushInterceptor;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // 声明自定义的拦截器对象和要拦截的请求路径
        registry.addInterceptor(antiRushInterceptor) .addPathPatterns("/api/**") ;
    }

    @Bean
    @Qualifier(DispatcherServletAutoConfiguration.DEFAULT_DISPATCHER_SERVLET_BEAN_NAME)
    public DispatcherServlet dispatcherServlet() {
        return new NewDispatcherServlet();
    }
}

还需要一些常规的类:

  • 自定义异常类CustomException
  • 自定义返回值类ResponseResult
  • CommonCode类
  • -ResultCode接口
  • UserDto实体类用于传参,其中要有phone属性传递要发短信的手机号
    上述代码都属于常规代码,在此略过

接下来,在项目的controller类中需要使用防刷策略的方法上加上注解@AntiRushCheck(“策略主题”)就可以了.
如果想对一个方法实现多个策略,用@AntiRushChecks容器包含多个@AntiRushCheck即可,具体格式如下:

	@AntiRushChecks({
         @AntiRushCheck("sendsmsphone"),
         @AntiRushCheck("sendsmsip")
     })

头一次写这么长的, 希望多多提意见!

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值