【SpringBoot】获取request请求参数,多次读取报错问题 (has already been called for this request)

应用场景 :

        因项目中接口请求时, 需要对请求参数进行签名验证。 当请求参数的body中有基本类型时(例: int, long, boolean等),因为基本类型如果没传值,序列化的时候会有默认值的问题, 最后导致实际接口调用生成的签名和项目中进行校验的签名不匹配。如果直接从request中获取请求参数body, 会出现request请求流重复读取异常,因此需要实现HttpServletRequestWrapper 重写getInputStream()和getReader()方法,将请求参数body复制到自己requestWrapper中, 后续只操作自己的requestWrapper

代码实现

启动类加上@ServletComponentScan 注解,开启Servlet、Filter、Listener可以直接通过@WebServlet、@WebFilter、@WebListener注解自动注册

//开启Servlet、Filter、Listener可以直接通过@WebServlet、@WebFilter、@WebListener注解自动注册
@ServletComponentScan 
@SpringBootApplication(scanBasePackages = {"xxx.xxx.xx"})
public class StartApplication {
    public static void main(String[] args) {
        SpringApplication.run(StartApplication.class, args);
    }
}

自定义filter, 必须将自定义的wrapper通过过滤器传下去, 不传不会调用重写后的getInputStream()和getReader()方法

@Component
@WebFilter(filterName = "RewriteRequestFilter", urlPatterns = "/*")
@Order(1)
public class RewriteRequestFilter implements Filter {
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        //文件上传类型 不需要处理,否则会报java.nio.charset.MalformedInputException: Input length = 1异常
        if (Objects.isNull(request) || Optional.ofNullable(request.getContentType()).orElse(StringUtils.EMPTY).startsWith("multipart/")) {
            chain.doFilter(request, response);
            return;
        }
        //自定义wrapper 处理流,必须在过滤器中处理,然后通过FilterChain传下去, 否则重写后的getInputStream()方法不会被调用
        MyHttpServletRequestWrapper requestWrapper = new MyHttpServletRequestWrapper((HttpServletRequest)request);
        chain.doFilter(requestWrapper,response);
    }
}

自定义wrapper复制请求流, 如果不重写会报 java.lang.IllegalStateException: getInputStream() has already been called for this request 异常, 原因是request请求流不能重复读取。

@Slf4j
@Getter
public class MyHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /** 复制请求body */
    private final String body;

    public MyHttpServletRequestWrapper (HttpServletRequest request) {
        super(request);
        try {
            //设置编码格式, 防止中文乱码
            request.setCharacterEncoding("UTF-8");
            //将请求中的流取出来放到body里,后面都只操作body就行
            this.body = RequestReadUtils.read(request);
        } catch (Exception e) {
            log.error("MyHttpServletRequestWrapper exception", e);
            throw new RuntimeException("MyHttpServletRequestWrapper 拦截器异常");
        }
    }

    @Override
    public ServletInputStream getInputStream()  {
        //返回body的流信息即可
        try(final ByteArrayInputStream bais = new ByteArrayInputStream(body.getBytes())){
            return getServletInputStream(bais);
        }catch(IOException e){
            log.error("MyHttpServletRequestWrapper.getInputStream() exception", e);
            throw new RuntimeException("MyHttpServletRequestWrapper 获取input流异常");
        }
    }

    @Override
    public BufferedReader getReader(){
        return new BufferedReader(new InputStreamReader(this.getInputStream()));
    }

    /**
     * 重写getInputStream流
     * @param bais
     * @return
     */
    private static ServletInputStream getServletInputStream(ByteArrayInputStream bais) {
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }
            @Override
            public int read() {
                return bais.read();
            }
        };
    }
}

读取请求流工具类

@Slf4j
public class RequestReadUtils {
    /**
     * 读取请求流
     * @param request
     * @return
     * @throws UnsupportedEncodingException
     */
    public static String read(HttpServletRequest request){
        try(BufferedReader reader = request.getReader()){
            StringBuilder sb = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
            return sb.toString();
        }catch (Exception e){
            log.error("MyHttpServletRequestWrapper.RequestReadUtils.readexception", e);
            throw new RuntimeException("MyHttpServletRequestWrapper.RequestReadUtils.read 获取请求流异常");
        }
    }
}

自定义interceptor拦截器, 判断如果是自己的wrapper, 从wrapper中获取请求参数。只能在拦截器中获取请求参数,

1.如果在自定义的Filter中获取请求参数, restful风格的请求参数无法获取。

2.如果在AOP中获取请求参数, 获取的是序列化后的请求参数(基本类型默认值也会被获取) 

@Slf4j
public class CommonInterceptor extends BaseInterceptor {
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
        //如果是自定义wrapper, 从自定义wrapper中获取请求参数, 必须在interceptor拦截器中处理, 否则restful风格的请求参数获取不到。
        if(request instanceof MyHttpServletRequestWrapper){
            Map<String, Object> requestParam = getRequestParam((MyHttpServletRequestWrapper)request);
            //放到ThreadLocal中, 这里可以根据自己的项目业务处理
            CommonData.set("param", requestParam);
        }
        return true;
    }

    @Override
    public void afterCompletion(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse,
        Object o, Exception e) throws Exception {
        //清空ThreadLocal, 防止内存泄漏
        clearAllData();
    }

    private void clearAllData() {
        CommonData.clearAll();
    }


    /**
     * 从request获取参数
     * @param request
     * @return
     * @throws IOException
     */
    private Map<String, Object> getRequestParam(CallHttpServletRequestWrapper request){
        Map<String, Object> paramMap = Maps.newHashMap();

        //获取使用@RequestParam注解的参数
        Map<String, String[]> parameterMap = request.getParameterMap();
        if(!CollectionUtils.isEmpty(parameterMap)){
            parameterMap.forEach((k,v)->{
                if(Objects.nonNull(v) && v.length > 0){
                    paramMap.put(k, v[0]);
                }
            });
        }

        //获取restful请求参数,必须在interceptor拦截其中才能这样获取到restful参数
        Object attribute = request.getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE);
        if(Objects.nonNull(attribute)){
            Map<String, Object> attributeMap = (Map<String, Object>)attribute;
            if(!CollectionUtils.isEmpty(attributeMap)){
                paramMap.putAll(attributeMap);
            }
        }
        //从自定义wrapper中, 获取body体参数
        String bodyString = request.getBody();
        if(StringUtils.isBlank(bodyString)){
            return paramMap;
        }

        //解析body参数
        Map<String, Object> bodyMap = parseRequestMap(bodyString);
        if(CollectionUtils.isEmpty(bodyMap)){
            return paramMap;
        }
        paramMap.putAll(bodyMap);
        return paramMap;
    }

    /**
     * 解析body请求参数
     * @param bodyString
     * @return
     */
    private Map<String, Object> parseRequestMap(String bodyString) {
        Map<String, Object> paramMap = Maps.newHashMap();
        boolean validObject = JSONObject.isValidObject(bodyString);
        //解析@ReqeustBody注解参数
        if(validObject){
            JSONObject jsonObject = JSONObject.parseObject(bodyString);
            paramMap.putAll(jsonObject);
        }else{
            //解析url拼接参数 例 a=123&b=456, 没有加@RequestBoyd注解的post请求
            String[] param = bodyString.split(SpecialConstant.AND);
            if(param.length == 0){
                return paramMap;
            }
            Stream.of(param).forEach(e->{
                String[] split = e.split(SpecialConstant.EQ);
                if(split.length == 0){
                    return;
                }
                paramMap.put(split[0], split[1]);
            });
        }
        return paramMap;
    }
}

自定义ThreadLocal工具类

@Slf4j
@Getter
@Setter
public class CommonData {
    private static ThreadLocal<Map<String, Object>> threadLocal = new ThreadLocal<>();

     /**
     * 添加数据
     * @param key
     * @param value
     */
    public static void set(String key, Object value) {
        if (threadLocal.get() == null) {
            Map<String, Object> map = new HashMap<>();
            threadLocal.set(map);
        }
        threadLocal.get().put(key, value);
    }

   
     /**
     * 清除数据
     */
    public static void clearAll() {
        threadLocal.set(null);
    }
 

    public static Map<String, Object> getSignParam() {
        Object o = threadLocal.get().get("param");
        if (Objects.isNull(o)) {
            log.info("CommonData.getSignParam is null");
            return null;
        }
        return (Map<String, Object>) o;
    }
}

至此通过自定义wrapper重复读取request请求流的方式完成, 也不会再报 java.lang.IllegalStateException: getInputStream() has already been called for this request异常

最后感谢大家的阅读, 如问题请随时指出!!!

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值