Implementing API Idempotency


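1. Token-based

The token-based scheme is simple. The whole flow has two steps:

  1. The client sends a request to obtain a token from the server. Every such request returns a brand-new token.
  2. The client attaches the token from step 1 to the business request. Before handling the request, the server checks that the token exists; once the request has been handled successfully, the server deletes the token.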
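
The two steps can be sketched in plain Java, with an in-memory map standing in for the server-side token store. The class `IdempotencyTokens` and its method names are hypothetical and only for illustration. One deliberate difference from the description above: this sketch deletes the token at check time, in a single atomic operation, rather than after the request succeeds, so that two concurrent requests carrying the same token cannot both pass.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

/** In-memory stand-in for the token store (hypothetical; a multi-node service would need a shared store such as Redis). */
final class IdempotencyTokens {
    private final Map<String, Boolean> issued = new ConcurrentHashMap<>();

    /** Step 1: hand out a brand-new token on every call. */
    String issue() {
        String token = UUID.randomUUID().toString();
        issued.put(token, Boolean.TRUE);
        return token;
    }

    /**
     * Step 2: check and delete the token in one atomic operation.
     * Returns true only for the first request that presents a given token.
     */
    boolean consume(String token) {
        // ConcurrentHashMap rejects null keys, so guard against a missing token first
        return token != null && issued.remove(token) != null;
    }
}
```

A request handler would call `consume` before doing any work and reject the request when it returns false.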

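2. Request-parameter validation

Interceptor approach

This approach requires Redis.

An interceptor intercepts the handler methods that need idempotency

and decides whether the incoming request is a repeated submission.

Duplicates are identified by URL + parameters + the user's authentication token. In the code that follows, the URL and the token form the Redis key, and the parameters are stored in the value and compared with the previous request.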
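
As a rough illustration of how such a key can be put together, here is a small sketch in plain Java. `RepeatKeys`, `cacheKey`, and `fingerprint` are hypothetical names. The SHA-256 fingerprint is a variation that is not part of the original code: it shows one way to fold a long parameter string into a fixed-length value if you want the parameters inside the key itself.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

final class RepeatKeys {
    private static final String PREFIX = "REPEAT_SUBMIT_KEY";

    /** Key = fixed prefix + URL + token; a missing token contributes nothing. */
    static String cacheKey(String url, String token) {
        return PREFIX + url + (token == null ? "" : token);
    }

    /** Hex-encoded SHA-256 of the raw parameters (illustrative variation, not in the original code). */
    static String fingerprint(String params) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            byte[] digest = md.digest(params.getBytes(StandardCharsets.UTF_8));
            StringBuilder hex = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                hex.append(String.format("%02x", b));
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException e) {
            // SHA-256 is a required algorithm on every Java platform
            throw new IllegalStateException(e);
        }
    }
}
```

Identical parameters always produce the same fingerprint, so `cacheKey(url, token) + fingerprint(params)` could serve as a per-request key.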

1. Set up a Spring Boot project

pom.xml dependencies:

web, redis, commons-lang3

 <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <!-- Redis dependency; the version is managed by the Spring Boot parent POM -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
        </dependency>
    </dependencies>

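Configuration file

application.yml

server:
  port: 8001

spring:
  redis:
    # host
    host: 127.0.0.1
    # port, 6379 by default
    port: 6379
    # password (none is set in this example)
    # connection timeout
    timeout: 10s
    # Lettuce pooling takes effect only when commons-pool2 is on the classpath
    lettuce:
      pool:
        # minimum number of idle connections in the pool
        min-idle: 0
        # maximum number of idle connections in the pool
        max-idle: 8
        # maximum number of connections the pool can allocate
        max-active: 8
        # maximum time to block waiting for a connection (a negative value means no limit)
        max-wait: -1ms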

Redis configuration class

RedisTemplateConfig

@EnableCaching
@Configuration
public class RedisTemplateConfig extends CachingConfigurerSupport {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);

        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // keys use String serialization
        redisTemplate.setKeySerializer(stringRedisSerializer);
        // hash keys also use String serialization
        redisTemplate.setHashKeySerializer(stringRedisSerializer);
        // values use JDK serialization
        redisTemplate.setValueSerializer(new JdkSerializationRedisSerializer());
        // hash values use JDK serialization
        redisTemplate.setHashValueSerializer(new JdkSerializationRedisSerializer());
        redisTemplate.afterPropertiesSet();

        return redisTemplate;
    }

}

A small cache helper that wraps RedisTemplate

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.stereotype.Component;

import java.util.concurrent.TimeUnit;

/**
 * @author: wenyi
 * @create: 2022/12/8
 * @Description: thin wrapper around RedisTemplate
 */
@Component
public class RedisCache {
    @Autowired
    public RedisTemplate redisTemplate;

    public <T> void setCacheObject(final String key, final T value, final Integer timeout, final TimeUnit timeUnit) {
        redisTemplate.opsForValue().set(key, value, timeout, timeUnit);
    }

    public <T> T getCacheObject(final String key) {
        ValueOperations<String, T> operation = redisTemplate.opsForValue();
        return operation.get(key);
    }
}

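The body of an HttpServletRequest can be read only once; after the first read the input stream is exhausted.

So we build a request wrapper whose input stream can be read repeatedly. For the interceptor to see this wrapper, a servlet Filter has to wrap each incoming request in it and pass the wrapped request down the chain; that filter is not shown in this article.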
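
The read-once behaviour, and the caching trick the wrapper relies on, can be seen with plain java.io streams. This is a minimal sketch with no servlet API involved; `ReplayableBody` and its methods are hypothetical names used only for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

final class ReplayableBody {
    private final byte[] cached;

    private ReplayableBody(byte[] cached) {
        this.cached = cached;
    }

    /** Drains the one-shot stream exactly once and keeps the bytes in memory. */
    static ReplayableBody capture(InputStream oneShot) {
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[4096];
            int n;
            while ((n = oneShot.read(buf)) != -1) {
                out.write(buf, 0, n);
            }
            return new ReplayableBody(out.toByteArray());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Every call returns a fresh stream positioned at the start of the cached bytes. */
    ByteArrayInputStream newStream() {
        return new ByteArrayInputStream(cached);
    }

    /** The cached body decoded as UTF-8. */
    String asString() {
        return new String(cached, StandardCharsets.UTF_8);
    }
}
```

After `capture` returns, the original stream has nothing left to give, while `newStream()` can be called as many times as needed. The request wrapper applies the same idea to the servlet input stream.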

public class RepeatedlyRequestWrapper extends HttpServletRequestWrapper {
    private final byte[] body;

    public RepeatedlyRequestWrapper(HttpServletRequest request, ServletResponse response) throws IOException {
        super(request);
        request.setCharacterEncoding(Constants.UTF8);
        response.setCharacterEncoding(Constants.UTF8);

        body = HttpHelper.getBodyString(request).getBytes(Constants.UTF8);
    }

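    @Override
    public BufferedReader getReader() throws IOException {
        // Decode with the same charset that was used to cache the body
        return new BufferedReader(new InputStreamReader(getInputStream(), Constants.UTF8));
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        // Every call returns a fresh stream over the cached bytes
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return bais.read();
            }

            @Override
            public int available() throws IOException {
                // Remaining bytes, not the total body length
                return bais.available();
            }

            @Override
            public boolean isFinished() {
                // Finished once every cached byte has been consumed
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so reads never block
                return true;
            }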

            @Override
            public void setReadListener(ReadListener readListener) {

            }
        };
    }
}

Shared constants

public class Constants {
    /**
     * UTF-8 charset name
     */
    public static final String UTF8 = "UTF-8";
}

General-purpose HTTP helper

/**
 * General-purpose HTTP helper
 *
 */
public class HttpHelper {
    private static final Logger LOGGER = LoggerFactory.getLogger(HttpHelper.class);

    public static String getBodyString(ServletRequest request) {
        StringBuilder sb = new StringBuilder();
        // try-with-resources closes the reader and the underlying input stream
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(request.getInputStream(), StandardCharsets.UTF_8))) {
            String line;
            // readLine() strips line terminators, so a multi-line body is joined into one line
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            LOGGER.warn("Failed to read the request body", e);
        }
        return sb.toString();
    }
}
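
2. Define the idempotency annotation

Define a custom annotation and add it to any endpoint that needs idempotency handling; annotated endpoints are then checked for repeated submissions automatically.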

import java.lang.annotation.*;

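/**
 * Idempotency annotation for handler methods
 */
@Inherited
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RepeatSubmit {
    /**
     * Interval in milliseconds; an identical request arriving within this interval counts as a repeated submission
     */
    int interval() default 5000;

    /**
     * Message returned to the caller
     */
    String message() default "Repeated submission is not allowed, please try again later";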

}
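
3. Implement the interceptor

Define an abstract parent interceptor

It intercepts the handler method, looks up the @RepeatSubmit annotation on it, and calls isRepeatSubmit to decide whether the request is a repeated submission.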

import com.fasterxml.jackson.databind.ObjectMapper;
import com.ung.myUrl.annotation.RepeatSubmit;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

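/**
 * @author: wenyi
 * @create: 2022/12/8
 * @Description: abstract interceptor
 */
public abstract class RepeatSubmitInterceptor implements HandlerInterceptor {
    /**
     * Intercepts the handler method, looks up its @RepeatSubmit annotation, and calls isRepeatSubmit to decide whether the request is a repeated submission
     *
     * @param request  the current request
     * @param response the current response
     * @param handler  the chosen handler
     * @return true to continue processing, false if the request was rejected as a duplicate
     * @throws Exception
     */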
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (handler instanceof HandlerMethod) {
            HandlerMethod handlerMethod = (HandlerMethod) handler;
            Method method = handlerMethod.getMethod();
            RepeatSubmit annotation = method.getAnnotation(RepeatSubmit.class);
            if (annotation != null) {
                if (this.isRepeatSubmit(request, annotation)) {
                    Map<String, Object> map = new HashMap<>();
                    map.put("status", 500);
                    map.put("msg", annotation.message());
                    response.setContentType("application/json;charset=utf-8");
                    response.getWriter().write(new ObjectMapper().writeValueAsString(map));
                    return false;
                }
            }
            return true;
        } else {
            return true;
        }
    }

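    /**
     * Decides whether the request is a repeated submission; subclasses implement the concrete rule
     *
     * @param request    the current request
     * @param annotation the @RepeatSubmit annotation found on the handler method
     * @return true if the request is a repeated submission
     */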
    public abstract boolean isRepeatSubmit(HttpServletRequest request, RepeatSubmit annotation);
}

The subclass implements the concrete duplicate-detection rule

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.ung.myUrl.annotation.RepeatSubmit;
import com.ung.myUrl.filter.RepeatedlyRequestWrapper;
import com.ung.myUrl.utils.RedisCache;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;

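/**
 * @author: wenyi
 * @create: 2022/12/8
 * @Description: checks whether the request URL and data are the same as in the previous request.
 * If they are, and the previous request arrived within the interval configured on @RepeatSubmit, the request is a repeated submission.
 */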
@Component
public class SameUrlDataInterceptor extends RepeatSubmitInterceptor {
    public final String REPEAT_PARAMS = "repeatParams";

    public final String REPEAT_TIME = "repeatTime";
    public final static String REPEAT_SUBMIT_KEY = "REPEAT_SUBMIT_KEY";

    private String header = "Authorization";

    @Autowired
    private RedisCache redisCache;

    @SuppressWarnings("unchecked")
    @Override
    public boolean isRepeatSubmit(HttpServletRequest request, RepeatSubmit annotation) {
        String nowParams = "";
        if (request instanceof RepeatedlyRequestWrapper) {
            RepeatedlyRequestWrapper repeatedlyRequest = (RepeatedlyRequestWrapper) request;
            try {
                // Read the whole body; readLine() alone would drop everything after the first line
                nowParams = repeatedlyRequest.getReader().lines()
                        .collect(java.util.stream.Collectors.joining());
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        // The body is empty, so fall back to the request parameters
        if (StringUtils.isEmpty(nowParams)) {
            try {
                nowParams = new ObjectMapper().writeValueAsString(request.getParameterMap());
            } catch (JsonProcessingException e) {
                e.printStackTrace();
            }
        }
        Map<String, Object> nowDataMap = new HashMap<String, Object>();
        nowDataMap.put(REPEAT_PARAMS, nowParams);
        nowDataMap.put(REPEAT_TIME, System.currentTimeMillis());

        // Request URI, used as part of the cache key
        String url = request.getRequestURI();

        // Per-user value from the Authorization header; empty when the header is missing, so the key falls back to the URI alone
        String submitKey = StringUtils.trimToEmpty(request.getHeader(header));

        // Cache key: fixed prefix + URI + header value
        String cacheRepeatKey = REPEAT_SUBMIT_KEY + url + submitKey;

        // Note: this get-then-set sequence is not atomic, so two concurrent identical requests can both pass
        Object sessionObj = redisCache.getCacheObject(cacheRepeatKey);
        if (sessionObj != null) {
            Map<String, Object> sessionMap = (Map<String, Object>) sessionObj;
            if (compareParams(nowDataMap, sessionMap) && compareTime(nowDataMap, sessionMap, annotation.interval())) {
                return true;
            }
        }
        redisCache.setCacheObject(cacheRepeatKey, nowDataMap, annotation.interval(), TimeUnit.MILLISECONDS);
        return false;
    }

    /**
     * Checks whether the parameters are identical
     */
    private boolean compareParams(Map<String, Object> nowMap, Map<String, Object> preMap) {
        String nowParams = (String) nowMap.get(REPEAT_PARAMS);
        String preParams = (String) preMap.get(REPEAT_PARAMS);
        return nowParams.equals(preParams);
    }

    /**
     * Checks whether the two requests are less than interval milliseconds apart
     */
    private boolean compareTime(Map<String, Object> nowMap, Map<String, Object> preMap, int interval) {
        long time1 = (Long) nowMap.get(REPEAT_TIME);
        long time2 = (Long) preMap.get(REPEAT_TIME);
        return (time1 - time2) < interval;
    }
}
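
The rule implemented by the interceptor (same parameters, seen again within the interval) can be sketched without Redis or the servlet API. This is only an in-memory illustration: `RepeatWindow` and `isRepeat` are hypothetical names, and the caller passes the current time in explicitly so the behaviour is easy to follow. Unlike a separate get followed by a set, `ConcurrentHashMap.compute` performs the check and the update atomically per key.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class RepeatWindow {
    /** Last accepted submission for a key: its parameters and when it was seen. */
    private static final class Entry {
        final String params;
        final long timeMillis;

        Entry(String params, long timeMillis) {
            this.params = params;
            this.timeMillis = timeMillis;
        }
    }

    private final Map<String, Entry> lastByKey = new ConcurrentHashMap<>();

    /**
     * Returns true when the same params were recorded for this key less than
     * intervalMillis ago; otherwise records the submission and returns false.
     */
    boolean isRepeat(String key, String params, long nowMillis, long intervalMillis) {
        boolean[] repeat = {false};
        lastByKey.compute(key, (k, prev) -> {
            if (prev != null && prev.params.equals(params)
                    && nowMillis - prev.timeMillis < intervalMillis) {
                repeat[0] = true;
                return prev; // keep the original entry, as the interceptor does
            }
            return new Entry(params, nowMillis);
        });
        return repeat[0];
    }
}
```

Entries in this sketch are never evicted; in the Redis-based version the key's expiry takes care of that.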

Finally, register the interceptor

import com.ung.myUrl.interceptor.RepeatSubmitInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

/**
 * @author: wenyi
 * @create: 2022/12/8
 * @Description: registers the interceptor
 */
@Configuration
public class WebConfig implements WebMvcConfigurer {

    @Autowired
    RepeatSubmitInterceptor repeatSubmitInterceptor;

    /**
     * Custom interception rules
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(repeatSubmitInterceptor)
                .addPathPatterns("/**");
    }
}

4. Use the idempotency annotation

Finally, apply it in a controller

@RestController
public class HelloController {

    @PostMapping("/hello")
    @RepeatSubmit(interval = 100000)
    public String hello(@RequestBody String msg) {
        System.out.println("msg = " + msg);
        return "hello";
    }
}
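
5. Test

Send a POST request to http://localhost:8001/hello

Set a token in the Authorization header. No real token authentication is performed here, so any value will do.

The first request succeeds, and Redis now holds an entry whose key is the fixed prefix + URL + token.

Sending the same request again right away returns the "repeated submission is not allowed" message, which shows that the idempotency check works.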

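AOP approach

Implement the anti-repeat aspect PreventAop.java

The rough logic: define an aspect that uses the @Prevent annotation as its pointcut. In the aspect's before advice, collect the method's arguments (together with the client IP) and Base64-encode them. The full method name plus that Base64 string is the Redis key, the unencoded string is the Redis value, and the value of @Prevent is the Redis expiry in seconds. Each time the advice runs it looks the key up in Redis: if the key exists the call is rejected, otherwise the call is allowed through.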
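
The key construction and the "reject while the key is still live" rule can be sketched in plain Java. This is an in-memory illustration under stated assumptions, not the aspect itself: `PreventSketch`, `key`, and `tryAcquire` are hypothetical names, a map of expiry timestamps stands in for Redis, and the current time is passed in explicitly.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class PreventSketch {
    /** key -> expiry time in epoch milliseconds (stand-in for Redis keys with a TTL) */
    private final Map<String, Long> expiresAt = new ConcurrentHashMap<>();

    /** Key = full method name + Base64(client IP + serialized arguments). */
    static String key(String methodFullName, String ip, String argsJson) {
        byte[] raw = (ip + argsJson).getBytes(StandardCharsets.UTF_8);
        return methodFullName + Base64.getEncoder().encodeToString(raw);
    }

    /** Returns true if the call is allowed, false if the key has not expired yet. */
    boolean tryAcquire(String key, long nowMillis, long expireSeconds) {
        long newExpiry = nowMillis + expireSeconds * 1000;
        boolean[] allowed = {false};
        // compute() makes the check and the write one atomic step per key
        expiresAt.compute(key, (k, expiry) -> {
            if (expiry == null || expiry <= nowMillis) {
                allowed[0] = true;
                return newExpiry;
            }
            return expiry;
        });
        return allowed[0];
    }
}
```

With a 10-second window, a call at time 0 is allowed, a call at 5 seconds is rejected, and a call at 10 seconds is allowed again.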

pom dependencies

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.6.11</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.ung</groupId>
    <artifactId>controller-test</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>controller-test</name>
    <description>Demo project for Spring Boot</description>
    <packaging>jar</packaging>
    <properties>
        <java.version>1.8</java.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <!-- JSON dependency -->
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.68</version>
        </dependency>

        <!-- validation annotations -->
        <dependency>
            <groupId>org.hibernate.validator</groupId>
            <artifactId>hibernate-validator</artifactId>
            <version>6.0.7.Final</version>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>

        <!-- Redis dependency; the version is managed by the Spring Boot parent POM -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

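Configuration file: configure Redis

server:
  port: 8001
spring:
  redis:
    # host
    host: 127.0.0.1
    # port, 6379 by default
    port: 6379
    # password (none is set in this example)
    # connection timeout
    timeout: 10s
    # Lettuce pooling takes effect only when commons-pool2 is on the classpath
    lettuce:
      pool:
        # minimum number of idle connections in the pool
        min-idle: 0
        # maximum number of idle connections in the pool
        max-idle: 8
        # maximum number of connections the pool can allocate
        max-active: 8
        # maximum time to block waiting for a connection (a negative value means no limit)
        max-wait: -1ms

Redis configuration
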
package com.ung.controllertest.config;

import org.springframework.cache.annotation.CachingConfigurerSupport;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.JdkSerializationRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

@EnableCaching
@Configuration
public class RedisTemplateConfig extends CachingConfigurerSupport {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);

        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // keys use String serialization
        redisTemplate.setKeySerializer(stringRedisSerializer);
        // hash keys also use String serialization
        redisTemplate.setHashKeySerializer(stringRedisSerializer);
        // values use JDK serialization
        redisTemplate.setValueSerializer(new JdkSerializationRedisSerializer());
        // hash values use JDK serialization
        redisTemplate.setHashValueSerializer(new JdkSerializationRedisSerializer());
        redisTemplate.afterPropertiesSet();

        return redisTemplate;
    }

}

Define the RedisUtil helper class

package com.ung.controllertest.utils;

import com.alibaba.fastjson.JSON;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.*;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.util.concurrent.TimeUnit;

/**
 * Redis helper class
 */
@Component
public class RedisUtil {
    @Autowired
    private RedisTemplate redisTemplate;
    @Resource(name="redisTemplate")
    private ValueOperations<String, String> valueOperations;
    @Resource(name="redisTemplate")
    private HashOperations<String, String, Object> hashOperations;
    @Resource(name="redisTemplate")
    private ListOperations<String, Object> listOperations;
    @Resource(name="redisTemplate")
    private SetOperations<String, Object> setOperations;
    @Resource(name="redisTemplate")
    private ZSetOperations<String, Object> zSetOperations;

    /** Default expiry, in seconds */
    public final static long DEFAULT_EXPIRE = 60 * 60 * 24;
    /** Do not set an expiry */
    public final static long NOT_EXPIRE = -1;

    public void set(String key, Object value, long expire){
        if(expire != NOT_EXPIRE){
            // Write the value and its expiry in one command so the key cannot be left without a TTL
            valueOperations.set(key, toJson(value), expire, TimeUnit.SECONDS);
        } else {
            valueOperations.set(key, toJson(value));
        }
    }

    public void set(String key, Object value){
        set(key, value, DEFAULT_EXPIRE);
    }

    public <T> T get(String key, Class<T> clazz, long expire) {
        String value = valueOperations.get(key);
        if(expire != NOT_EXPIRE){
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value == null ? null : fromJson(value, clazz);
    }

    public <T> T get(String key, Class<T> clazz) {
        return get(key, clazz, NOT_EXPIRE);
    }

    public String get(String key, long expire) {
        String value = valueOperations.get(key);
        if(expire != NOT_EXPIRE){
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value;
    }

    public String get(String key) {
        return get(key, NOT_EXPIRE);
    }

    public void delete(String key) {
        redisTemplate.delete(key);
    }

    /**
     * Converts an object to a JSON string
     */
    private String toJson(Object object){
        if(object instanceof Integer || object instanceof Long || object instanceof Float ||
                object instanceof Double || object instanceof Boolean || object instanceof String){
            return String.valueOf(object);
        }
        return JSON.toJSONString(object);
    }

    /**
     * Converts a JSON string to an object
     */
    private <T> T fromJson(String json, Class<T> clazz){
        return JSON.parseObject(json, clazz);
    }
}

Define the Prevent annotation

package com.ung.controllertest.annotation;

import com.ung.controllertest.enums.PreventStrategyEnum;

import java.lang.annotation.*;

/**
 * Anti-repeat annotation for endpoints.
 * Usage: add this annotation to any method that needs protection against repeated calls.
 */
@Documented
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Prevent {

    /**
     * Length of the restricted window, in seconds
     */
    String value() default "60";

    /**
     * Message returned when a call is rejected
     */
    String message() default "";

    /**
     * Strategy
     */
    PreventStrategyEnum strategy() default PreventStrategyEnum.DEFAULT;
}

Define the strategy enum

package com.ung.controllertest.enums;

/**
 * Anti-repeat strategies
 */
public enum PreventStrategyEnum {
    /**
     * Default (a repeated request is not allowed within one minute)
     */
    DEFAULT
}
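
Implement the anti-repeat aspect PreventAop (BusinessException is the project's own exception class and is not shown here)
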
package com.ung.controllertest.aop;

import com.alibaba.fastjson.JSON;
import com.ung.controllertest.annotation.Prevent;
import com.ung.controllertest.enums.PreventStrategyEnum;
import com.ung.controllertest.exception.BusinessException;
import com.ung.controllertest.utils.RedisUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Base64;

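/**
 * Rough logic:
 * define an aspect that uses the @Prevent annotation as its pointcut. In the before advice, collect the method's
 * arguments (together with the client IP) and Base64-encode them. The full method name plus that Base64 string is
 * the Redis key, the unencoded string is the Redis value, and the value of @Prevent is the Redis expiry.
 * Each time the advice runs it looks the key up in Redis:
 * if the key exists the call is rejected, otherwise the call is allowed through.
 */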
@Aspect
@Component
@Slf4j
public class PreventAop {

    @Autowired
    private RedisUtil redisUtil;

    /**
     * Pointcut
     */
    @Pointcut("@annotation(com.ung.controllertest.annotation.Prevent)")
    public void pointcut() {
    }

    /**
     * Before advice
     */
    @Before("pointcut()")
    public void joinPoint(JoinPoint joinPoint) throws Exception {
        Object[] args = joinPoint.getArgs();
        // Get the client IP; the annotated method must declare HttpServletRequest as its first parameter
        HttpServletRequest request = (HttpServletRequest) args[0];
        String ip = getIp(request);
        log.info("Request from: {}", ip);
        // Collect the remaining arguments as JSON; start from an empty builder so no "null" prefix ends up in the key
        StringBuilder requestStr = new StringBuilder();
        // Only when there are arguments besides the request
        if (args.length != 1) {
            for (int i = 1; i < args.length; i++) {
                requestStr.append(JSON.toJSONString(args[i]));
            }
            if (requestStr.length() == 0 || "{}".equals(requestStr.toString())) {
                throw new BusinessException("[Prevent] arguments must not be empty");
            }
        }

        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = joinPoint.getTarget().getClass().getMethod(methodSignature.getName(),
                methodSignature.getParameterTypes());

        Prevent preventAnnotation = method.getAnnotation(Prevent.class);
        String methodFullName = method.getDeclaringClass().getName() + method.getName();

        entrance(preventAnnotation, ip + requestStr, methodFullName);
    }

    public String getIp(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /**
     * Entry point: dispatches to the handler for the configured strategy
     *
     * @param prevent        the annotation on the intercepted method
     * @param requestStr     client IP plus serialized arguments
     * @param methodFullName declaring class name plus method name
     */
    private void entrance(Prevent prevent, String requestStr, String methodFullName) throws Exception {
        PreventStrategyEnum strategy = prevent.strategy();
        switch (strategy) {
            case DEFAULT:
                defaultHandle(requestStr, prevent, methodFullName);
                break;
            default:
                throw new BusinessException("invalid strategy");
        }
    }

    /**
     * Default strategy: reject the call while the key for this request still exists in Redis
     *
     * @param requestStr client IP plus serialized arguments
     * @param prevent    the annotation on the intercepted method
     */
    private void defaultHandle(String requestStr, Prevent prevent, String methodFullName) throws Exception {
        String base64Str = toBase64String(requestStr);
        long expire = Long.parseLong(prevent.value());

        // Note: get followed by set is not atomic, so concurrent identical requests can both pass
        String resp = redisUtil.get(methodFullName + base64Str);
        if (StringUtils.isEmpty(resp)) {
            redisUtil.set(methodFullName + base64Str, requestStr, expire);
        } else {
            String message = !StringUtils.isEmpty(prevent.message()) ? prevent.message() :
                    "Repeated requests are not allowed within " + expire + " seconds";
            throw new BusinessException(message);
        }
    }

    /**
     * Encodes a string as Base64
     *
     * @param obj the string to encode
     * @return the Base64 string, or null if the input is empty
     */
    private String toBase64String(String obj) throws Exception {
        if (StringUtils.isEmpty(obj)) {
            return null;
        }
        Base64.Encoder encoder = Base64.getEncoder();
        byte[] bytes = obj.getBytes("UTF-8");
        return encoder.encodeToString(bytes);
    }
}
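
One detail about the IP lookup: behind a chain of proxies, an X-Forwarded-For value is typically a comma-separated list with the original client first, so using the whole header as the IP makes the key depend on the proxy path. The sketch below picks the first usable entry. `ForwardedFor` and `clientIp` are hypothetical names and this is not part of the original aspect. Keep in mind that the header is supplied by the client and can be forged, so it is only a weak identifier.

```java
final class ForwardedFor {
    /** Returns the first usable address in an X-Forwarded-For value, or the fallback if there is none. */
    static String clientIp(String headerValue, String fallback) {
        if (headerValue != null) {
            for (String part : headerValue.split(",")) {
                String candidate = part.trim();
                if (!candidate.isEmpty() && !"unknown".equalsIgnoreCase(candidate)) {
                    return candidate;
                }
            }
        }
        return fallback;
    }
}
```

In the aspect, the fallback would be the value of `request.getRemoteAddr()`.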

Use the annotation on controller endpoints

package com.ung.controllertest.controller;

import com.ung.controllertest.annotation.Prevent;
import com.ung.controllertest.pojo.dto.UserDto;
import com.ung.controllertest.service.UserService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;

import javax.servlet.http.HttpServletRequest;
import java.util.List;

@RestController
public class UserController {

    private UserService userService;

    @Autowired
    public UserController(UserService userService) {
        this.userService = userService;
    }

    @GetMapping("/user/getById/{id}")
    public UserDto getById(@PathVariable("id") Long id) {
        return userService.getById(id);
    }

    @GetMapping("/user/getList")
    @Prevent(value = "10", message = "Repeated requests are not allowed within 10 seconds")
    public List<UserDto> getList(HttpServletRequest request) {
        return userService.getList();
    }

    @PostMapping("/user/addUser")
    @Prevent(value = "10", message = "Repeated requests are not allowed within 10 seconds")
    public boolean addUser(HttpServletRequest request, @RequestBody UserDto userDto) {
        return userService.addUser(userDto);
    }

    @GetMapping("/user/getUserName/{id}")
    public String getUserName(@PathVariable("id") Long id) {
        return userService.getUserName(id);
    }
}

Finally, call the endpoints to test

http://localhost:8001/user/addUser

{
    "username": "username",
    "password": "123456",
    "email": "123@qq.com",
    "phone": "13900445200"
}

Sending the same request again in quick succession is rejected.

The check is equally effective for a GET request that has no request parameters:

http://localhost:8001/user/getList
