利用HttpServletRequestWrapper来支持可重复读取HttpServletRequest中的请求输入流且不影响Controller层的参数获取

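HttpServletRequest 的请求输入流为什么不能重复读取,网上资料很多,这里不再赘述。常见的解决方案有两种:一是使用 Spring(spring-web)自带的 ContentCachingRequestWrapper;二是自定义请求包装器。需要注意,ContentCachingRequestWrapper 只会在请求体被读取的同时进行缓存。因此在拦截器的 preHandle 阶段,也就是 Controller 读取请求体之前,它拿不到请求体内容;如果需要在这个阶段读取请求体,就要采用方案二。方案二的具体步骤如下: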

1. 定义请求包装器,继承于 HttpServletRequestWrapper


import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.MediaType;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.net.URLDecoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Request wrapper that buffers the complete request body in memory so it can be read any
 * number of times (via {@link #getInputStream()}, {@link #getReader()} or {@link #getBody()}).
 *
 * <p>Because the original body stream is consumed in the constructor, the servlet container
 * can no longer parse {@code application/x-www-form-urlencoded} body fields itself. This
 * wrapper therefore parses them and merges them into the {@code getParameter*} methods
 * (modelled on Spring's {@code HttpPutFormContentFilter}).
 *
 * <p>The whole body is kept in memory; do not wrap requests with very large bodies.
 */
@Slf4j
public class MyRequestWrapper extends HttpServletRequestWrapper {

    /** Charset used when the request declares none, or declares one the JVM does not support. */
    public static final Charset defaultCharset = StandardCharsets.UTF_8;

    /** Content-Type parameter prefix that carries the charset name (matched case-insensitively). */
    private static final String CHARSET_PARAM = "charset=";

    /** Charset used to decode the cached body and form fields. */
    private final Charset charset;

    /** Exact bytes of the original request body. */
    private final byte[] bodyBytes;

    /** Form fields decoded from an urlencoded body; empty for any other content type. */
    private final MultiValueMap<String, String> formParameters;

    /**
     * Wraps {@code request}, reading its body stream to the end.
     *
     * @param request the request to wrap; its input stream is fully consumed here
     * @throws IOException if reading the body fails
     */
    public MyRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        String contentType = request.getContentType();
        charset = resolveCharset(request.getCharacterEncoding(), contentType);
        bodyBytes = handleInputStream(request.getInputStream());
        if (StringUtils.isNotBlank(contentType)
                && contentType.contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE)) {
            formParameters = handleFormParameters();
        } else {
            formParameters = new LinkedMultiValueMap<>(0);
        }
    }

    /**
     * Picks the body charset: the request's character encoding first, then the
     * {@code charset=} parameter of the Content-Type, then {@link #defaultCharset}.
     * Unknown or malformed names (client-controlled) fall back to the default instead of
     * failing the whole request.
     */
    private static Charset resolveCharset(String charEncode, String contentType) {
        String name = null;
        if (StringUtils.isNotBlank(charEncode)) {
            name = charEncode.trim();
        } else if (StringUtils.isNotBlank(contentType)) {
            int start = contentType.toLowerCase(Locale.ROOT).indexOf(CHARSET_PARAM);
            if (start >= 0) {
                start += CHARSET_PARAM.length();
                // The charset parameter may be followed by further parameters, e.g. "; boundary=..."
                int end = contentType.indexOf(';', start);
                name = (end < 0 ? contentType.substring(start) : contentType.substring(start, end)).trim();
                // RFC 7231 allows the value to be a quoted string: charset="utf-8"
                if (name.length() >= 2 && name.startsWith("\"") && name.endsWith("\"")) {
                    name = name.substring(1, name.length() - 1);
                }
            }
        }
        if (StringUtils.isBlank(name)) {
            return defaultCharset;
        }
        try {
            return Charset.forName(name);
        } catch (IllegalArgumentException e) {
            // Covers both IllegalCharsetNameException and UnsupportedCharsetException
            log.warn("Unsupported request charset '{}', falling back to {}", name, defaultCharset);
            return defaultCharset;
        }
    }

    /**
     * Reads the stream to the end and returns the raw bytes unchanged.
     * Line terminators and binary content are preserved. The stream is not closed here
     * because the container owns it.
     */
    private byte[] handleInputStream(InputStream inputStream) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        int read;
        while ((read = inputStream.read(buffer)) != -1) {
            out.write(buffer, 0, read);
        }
        return out.toByteArray();
    }

    /**
     * Returns the cached request body decoded with the resolved request charset.
     *
     * @return the body text; empty string when the request had no body
     */
    public String getBody() {
        return new String(bodyBytes, this.charset);
    }

    /**
     * Parses the urlencoded body into name/value pairs (modelled on HttpPutFormContentFilter).
     * A pair without '=' is stored with a {@code null} value.
     */
    private MultiValueMap<String, String> handleFormParameters() throws UnsupportedEncodingException {
        String[] pairs = tokenizeToStringArray(getBody(), "&", true, true);
        MultiValueMap<String, String> result = new LinkedMultiValueMap<>(pairs.length);
        String encoding = charset.name();
        for (String pair : pairs) {
            int idx = pair.indexOf('=');
            if (idx == -1) {
                result.add(URLDecoder.decode(pair, encoding), null);
            } else {
                String name = URLDecoder.decode(pair.substring(0, idx), encoding);
                String value = URLDecoder.decode(pair.substring(idx + 1), encoding);
                result.add(name, value);
            }
        }
        return result;
    }

    /**
     * Splits {@code str} on any of the {@code delimiters} characters, optionally trimming
     * each token and dropping empty ones. Returns an empty array for {@code null}.
     */
    private String[] tokenizeToStringArray(String str, String delimiters, boolean trimTokens, boolean ignoreEmptyTokens) {
        if (str == null) {
            return new String[0];
        }
        StringTokenizer st = new StringTokenizer(str, delimiters);
        List<String> tokens = new ArrayList<>();
        while (st.hasMoreTokens()) {
            String token = st.nextToken();
            if (trimTokens) {
                token = token.trim();
            }
            if (ignoreEmptyTokens && token.isEmpty()) {
                continue;
            }
            tokens.add(token);
        }
        return tokens.toArray(new String[0]);
    }

    /**
     * Returns a fresh stream over the cached body on every call, so the body can be re-read.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream inputStream = new ByteArrayInputStream(bodyBytes);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return inputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                // Data is already in memory, so a read never blocks
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // A listener would never be called back by this in-memory stream; fail loudly
                // instead of letting a non-blocking reader wait forever.
                throw new UnsupportedOperationException("Non-blocking reads are not supported");
            }

            @Override
            public int read() {
                return inputStream.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return inputStream.read(b, off, len);
            }
        };
    }

    /**
     * Returns a reader over the cached body, decoded with the same charset as {@link #getBody()}.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }

    /**
     * Query-string value first, falling back to the first form-body value.
     */
    @Override
    public String getParameter(String name) {
        String queryStringValue = super.getParameter(name);
        String formValue = this.formParameters.getFirst(name);
        return queryStringValue != null ? queryStringValue : formValue;
    }

    /**
     * Builds a new map of all parameter names (query and form body) to their values.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> result = new LinkedHashMap<>();
        Enumeration<String> names = this.getParameterNames();
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            result.put(name, this.getParameterValues(name));
        }
        return result;
    }

    /**
     * Union of query-string parameter names and form-body field names, in insertion order.
     */
    @Override
    public Enumeration<String> getParameterNames() {
        Set<String> names = new LinkedHashSet<>();
        names.addAll(Collections.list(super.getParameterNames()));
        names.addAll(this.formParameters.keySet());
        return Collections.enumeration(names);
    }

    /**
     * Query-string values followed by form-body values when both exist; otherwise whichever exists.
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] parameterValues = super.getParameterValues(name);
        List<String> formParam = this.formParameters.get(name);
        if (formParam == null) {
            return parameterValues;
        } else if (parameterValues != null && this.getQueryString() != null) {
            List<String> result = new ArrayList<>(parameterValues.length + formParam.size());
            result.addAll(Arrays.asList(parameterValues));
            result.addAll(formParam);
            return result.toArray(new String[0]);
        } else {
            return formParam.toArray(new String[0]);
        }
    }

}

2. 定义过滤器,用于将 HttpServletRequest 对象替换成自定义请求包装器 MyRequestWrapper 对象


import lombok.extern.slf4j.Slf4j;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * Filter that replaces the incoming {@link HttpServletRequest} with a {@link MyRequestWrapper}.
 * Filters and controllers further down the chain can then read the request body more than once.
 *
 * <p>Requests are passed through unwrapped in three cases:
 * <ul>
 *   <li>the request is not an HTTP request;</li>
 *   <li>the request is already wrapped;</li>
 *   <li>the request is multipart. The container parses multipart data ({@code getParts()})
 *       from the original stream, so consuming it here would lose uploaded files. It would
 *       also buffer whole files in memory.</li>
 * </ul>
 */
@Slf4j
public class MyRequestWrapperFilter implements Filter {

    private static final String MULTIPART_PREFIX = "multipart/";

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // nothing to initialise
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        if (!shouldWrap(servletRequest)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        log.debug("执行HttpServletRequestWrapper包装器替换");
        MyRequestWrapper requestWrapper = new MyRequestWrapper((HttpServletRequest) servletRequest);
        filterChain.doFilter(requestWrapper, servletResponse);
    }

    /** True for HTTP, not-yet-wrapped, non-multipart requests. */
    private static boolean shouldWrap(ServletRequest request) {
        if (!(request instanceof HttpServletRequest) || request instanceof MyRequestWrapper) {
            return false;
        }
        String contentType = request.getContentType();
        return contentType == null
                || !contentType.regionMatches(true, 0, MULTIPART_PREFIX, 0, MULTIPART_PREFIX.length());
    }

    @Override
    public void destroy() {
        // nothing to release
    }
}

3. 定义拦截器,提前获取请求body数据


import com.alibaba.fastjson.JSON;
import com.me.config.filter.MyRequestWrapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.util.WebUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Map;

/**
 * Interceptor that logs basic request information before the handler runs and always lets
 * the request continue. It logs:
 * <ul>
 *   <li>the client address;</li>
 *   <li>the method and URL;</li>
 *   <li>the parameters;</li>
 *   <li>the body, when the request was wrapped by {@link MyRequestWrapper}.</li>
 * </ul>
 *
 * <p>NOTE(review): bodies and parameters are logged at INFO in full. They may contain
 * credentials or personal data, so consider masking them or lowering the level.
 */
@Slf4j
public class AccessInterceptor implements HandlerInterceptor {
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        Map<String, String[]> parameterMap;
        String body;
        // Unwrap any further request wrappers to find the body-caching wrapper, if present
        MyRequestWrapper requestWrapper = WebUtils.getNativeRequest(request, MyRequestWrapper.class);
        if (requestWrapper != null) {
            body = requestWrapper.getBody();
            parameterMap = requestWrapper.getParameterMap();
        } else {
            // Body is not cached: do not read the stream here, or the controller could not read it again
            body = "";
            parameterMap = request.getParameterMap();
        }
        log.info("clientURL----->{}", request.getRemoteAddr());
        log.info("requestURL----->{} {}", request.getMethod(), request.getRequestURL().toString());
        log.info("parameterMap----->{}", JSON.toJSONString(parameterMap));
        log.info("requestBody----->{}", body);
        return true;
    }
}

4. 在配置类中注册过滤器与拦截器,使其生效


import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import org.springframework.web.filter.CorsFilter;
import org.springframework.web.servlet.config.annotation.*;

@Slf4j
@Configuration
public class BaseConfig extends WebMvcConfigurationSupport {

	@Bean
	public MyRequestWrapperFilter getMyRequestWrapperFilter() {
		return new MyRequestWrapperFilter();
	}

	@Bean
	public FilterRegistrationBean someFilterRegistration() {
		FilterRegistrationBean registration = new FilterRegistrationBean();
		registration.setFilter(getMyRequestWrapperFilter());
		registration.addUrlPatterns("/*");
		registration.setName("myInputStreamFilter");
		// 设置过滤器执行优先级,数值越小优先级越高
		// 当前只要保证Filter过滤器在Interceptor拦截器之前执行即可,所以设置为最低优先级
		registration.setOrder(Ordered.LOWEST_PRECEDENCE);
		return registration;
	}

	@Bean
	public AccessInterceptor getAccessInterceptor() {
		return new AccessInterceptor();
	}

	@Override
	public void addInterceptors(InterceptorRegistry registry) {
		registry.addInterceptor(getAccessInterceptor())
				.addPathPatterns("/**")
				.excludePathPatterns("/static/**");
    }

	/**
	 * 解决跨域问题
	 */
	@Override
	public void addCorsMappings(CorsRegistry registry) {
		registry.addMapping("/**")
				.allowedOrigins("*")
				.allowedHeaders("Content-Type,X-Requested-With,Cookies,Cookie,X-Auth-Token,token,auth,Authorization")
				.allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS")
				.allowCredentials(true)
				.maxAge(3600);
	}
	@Bean
	public CorsFilter corsFilter() {
		UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
		CorsConfiguration corsConfiguration = new CorsConfiguration();
		// 请求常用的三种配置,*代表允许所有,当时你也可以自定义属性(比如header只能带什么,method只能是post方式等等)
		corsConfiguration.addAllowedOrigin("*");
		corsConfiguration.addAllowedHeader("Content-Type,X-Requested-With,Cookies,Cookie,X-Auth-Token,token,auth,Authorization");
		corsConfiguration.addAllowedMethod("GET,POST,PUT,DELETE,OPTIONS");
		corsConfiguration.setAllowCredentials(true);
		corsConfiguration.setMaxAge(3600L);
		source.registerCorsConfiguration("/**", corsConfiguration);
		return new CorsFilter(source);
	}

}

5. 定义Controller类,这里就不贴源码了,与正常写法无差异。

总结:最主要的在于请求包装器 MyRequestWrapper 中对输入流的处理和存储以及重写获取参数的几个方法,再加上过滤器中的请求包装器替换来保证后续输入流和参数的正常可重复获取。

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
如果您的HttpServletRequest拦截器在读取一次请求数据之后再次读取时无法读取数据,可能是因为HttpServletRequest对象的输入流只能读取一次,如果您已经读取了它,那么它将不再可用。 为了解决这个问题,您可以将HttpServletRequest对象的请求数据读取到一个字节数组中,并将字节数组包装在一个新的HttpServletRequest对象中,然后将新的HttpServletRequest对象用于后续处理。 以下是一个示例: ```java public class MyInterceptor implements HandlerInterceptor { public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { // 读请求数据到字节数组 byte[] requestBody = IOUtils.toByteArray(request.getInputStream()); // 创建新的HttpServletRequest对象,并将字节数组包装在里面 HttpServletRequestWrapper requestWrapper = new HttpServletRequestWrapper(request) { @Override public ServletInputStream getInputStream() throws IOException { return new ServletInputStreamWrapper(requestBody); } @Override public int getContentLength() { return requestBody.length; } @Override public long getContentLengthLong() { return requestBody.length; } }; // 将新的HttpServletRequest对象用于后续处理 // ... return true; } } ``` 在这个示例中,我们使用了IOUtils.toByteArray()方法将HttpServletRequest对象的输入流读取到一个字节数组中。然后,我们创建了一个新的HttpServletRequestWrapper对象,并将字节数组包装在里面。最后,我们将新的HttpServletRequestWrapper对象用于后续处理。 这样,即使HttpServletRequest对象的输入流只能读取一次,您也可以在拦截器中多次读取HttpServletRequest对象的请求数据。 需要注意:HandlerInterceptor 的 preHandle 只能返回 boolean,无法把新建的包装对象交给后续的处理流程。Controller 拿到的仍是原始 request,而它的输入流已被 IOUtils.toByteArray() 读完,因此 @RequestBody 等参数将无法再读取。另外,示例中的 ServletInputStreamWrapper 并非 Servlet 标准类,需要自行实现。要让后续环节都能重复读取请求体,应像正文那样在 Filter 中包装请求,并通过 filterChain.doFilter 传递包装后的对象。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值