拦截器中读取 request 的输入流后,Controller 无法获取到数据的解决方案

36 篇文章 0 订阅

一般我们会在 HandlerInterceptorAdapter 拦截器中对请求的 token 进行验证。

如果请求的 Content-Type 是 application/x-www-form-urlencoded,则没有什么问题(参数可以通过 request.getParameter() 获取,无需直接读取输入流)。

如果我们的接口是用 @RequestBody 来接收数据的,那么我们在拦截器中验证 token 时就需要读取 request 的输入流。而 ServletRequest 的请求体只能被读取一次:getReader() 和 getInputStream() 只能二选一,流一旦被读完就无法再次读取。

这样就会导致 Controller 无法再拿到数据。


解决方法 :


自定义一个请求包装类 BodyReaderHttpServletRequestWrapper.java。它在构造时把请求体缓存为字节数组,之后每次调用 getInputStream() / getReader() 都基于该缓存返回一个新的流:


import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.StringTokenizer;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import com.microdata.core.util.Encodes;

/**
 * Request wrapper that buffers the request body in memory so it can be read more than once
 * (for example once by an interceptor that checks a token and again by a {@code @RequestBody}
 * controller method).
 *
 * <p>Every call to {@link #getInputStream()} or {@link #getReader()} returns a fresh stream over
 * the buffered bytes. Reading the body consumes the container's stream, so the container can no
 * longer parse form parameters itself. This class therefore re-implements parameter lookup:
 * query-string parameters are always exposed. They are followed by body parameters only when the
 * request is a {@code POST} with {@code Content-Type: application/x-www-form-urlencoded}. Other
 * body types (e.g. JSON) are never parsed as parameters.
 *
 * <p>The whole body is held in memory, so do not wrap requests with unbounded bodies.
 */
public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {

    private static final String FORM_MEDIA_TYPE = "application/x-www-form-urlencoded";

    /** Raw request body, captured once in the constructor. */
    private final byte[] body;

    /** Unmodifiable parameter view: query-string values first, then form-body values. */
    private final Map<String, String[]> paramsMap;

    /**
     * Buffers the body of {@code request} and pre-computes its parameters.
     *
     * @param request the request to wrap
     * @throws IOException if the body cannot be read
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        body = readBytes(request.getInputStream());
        // Use the static helpers (not overridable methods) while the object is still being built.
        paramsMap = Collections.unmodifiableMap(buildParamMap(request, body));
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        return paramsMap;
    }

    /** Returns the first value of {@code name}, or {@code null} if the parameter is absent. */
    @Override
    public String getParameter(String name) {
        String[] values = paramsMap.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    @Override
    public String[] getParameterValues(String name) {
        return paramsMap.get(name);
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(paramsMap.keySet());
    }

    /**
     * Parses an {@code application/x-www-form-urlencoded} string such as {@code a=1&b=2&a=3}.
     * Repeated keys accumulate their values in order. Pairs without {@code '='} are skipped.
     * Keys and values are URL-decoded as UTF-8. A malformed escape is kept verbatim instead of
     * failing the whole request.
     *
     * @param s the encoded string
     * @return a new map from parameter name to its values
     * @throws IllegalArgumentException if {@code s} is {@code null}
     */
    public HashMap<String, String[]> parseQueryString(String s) {
        if (s == null) {
            throw new IllegalArgumentException();
        }
        HashMap<String, String[]> result = new HashMap<>();
        parseInto(result, s);
        return result;
    }

    /**
     * Returns a reader over the buffered body. It uses the request's character encoding, or UTF-8
     * when none is set, rather than the platform default charset.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            encoding = "UTF-8";
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /** Returns a new stream positioned at the start of the buffered body. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] buffer, int offset, int length) {
                return source.read(buffer, offset, length);
            }

            @Override
            public int available() {
                return source.available();
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported by this wrapper; the listener is ignored.
            }
        };
    }

    /** Collects query-string parameters, then body parameters for urlencoded POST requests. */
    private static Map<String, String[]> buildParamMap(HttpServletRequest request, byte[] body) {
        Map<String, String[]> params = new LinkedHashMap<>();
        String query = request.getQueryString();
        if (query != null && !query.isEmpty()) {
            parseInto(params, query);
        }
        // Only form bodies are key/value pairs; JSON or other bodies must not be parsed here.
        if ("POST".equalsIgnoreCase(request.getMethod())
                && isFormContentType(request.getContentType())
                && body.length > 0) {
            parseInto(params, new String(body, StandardCharsets.UTF_8));
        }
        return params;
    }

    /** Returns true when the media type (ignoring parameters such as charset) is form-urlencoded. */
    private static boolean isFormContentType(String contentType) {
        if (contentType == null) {
            return false;
        }
        int semicolon = contentType.indexOf(';');
        String mediaType = (semicolon >= 0 ? contentType.substring(0, semicolon) : contentType).trim();
        return FORM_MEDIA_TYPE.equalsIgnoreCase(mediaType);
    }

    /** Parses {@code encoded} as {@code k=v&k=v} pairs and appends every value into {@code target}. */
    private static void parseInto(Map<String, String[]> target, String encoded) {
        StringTokenizer pairs = new StringTokenizer(encoded, "&");
        while (pairs.hasMoreTokens()) {
            String pair = pairs.nextToken();
            int eq = pair.indexOf('=');
            if (eq == -1) {
                continue;
            }
            String key = decodeValue(pair.substring(0, eq));
            String value = decodeValue(pair.substring(eq + 1));
            appendValue(target, key, value);
        }
    }

    /** Adds {@code value} to the end of the value array stored under {@code key}. */
    private static void appendValue(Map<String, String[]> target, String key, String value) {
        String[] existing = target.get(key);
        if (existing == null) {
            target.put(key, new String[] {value});
        } else {
            String[] grown = Arrays.copyOf(existing, existing.length + 1);
            grown[existing.length] = value;
            target.put(key, grown);
        }
    }

    /**
     * URL-decodes {@code value} as UTF-8.
     *
     * <p>{@code URLDecoder} rejects malformed escapes. This includes JavaScript-style
     * {@code %uXXXX}, which the previous special case could never decode. Such values are
     * returned unchanged so that one bad parameter does not abort the request.
     */
    private static String decodeValue(String value) {
        try {
            return URLDecoder.decode(value, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            return ""; // UTF-8 is always supported; kept for the checked exception.
        } catch (IllegalArgumentException e) {
            return value;
        }
    }

    /** Reads {@code in} to exhaustion and always closes it, even if reading fails. */
    private static byte[] readBytes(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream(1024);
        byte[] chunk = new byte[1024];
        try (InputStream source = in) {
            int size;
            while ((size = source.read(chunk)) != -1) {
                out.write(chunk, 0, size);
            }
        }
        return out.toByteArray();
    }
}


自定义 Filter:HttpServletRequestReplacedFilter.java,对 Content-Type 为 JSON 的 POST 请求,用上面的包装类替换原始 request:

import java.io.IOException;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

import com.microdata.core.request.BodyReaderHttpServletRequestWrapper;
 
/**
 * Servlet filter that replaces JSON {@code POST} requests with a
 * {@link BodyReaderHttpServletRequestWrapper}. Downstream components (interceptors, then the
 * controller) can then each read the request body.
 *
 * <p>The Content-Type check compares only the media type, case-insensitively, and ignores
 * parameters. So {@code application/json}, {@code application/json;charset=UTF-8} and
 * {@code application/json; charset=utf-8} all qualify. A missing Content-Type never does.
 */
public class HttpServletRequestReplacedFilter implements Filter {

    private static final String JSON_MEDIA_TYPE = "application/json";

    @Override
    public void destroy() {
        // Nothing to release.
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        ServletRequest forwarded = request;
        if (request instanceof HttpServletRequest) {
            HttpServletRequest httpRequest = (HttpServletRequest) request;
            // equalsIgnoreCase is null-safe and locale-independent, unlike toUpperCase().
            if ("POST".equalsIgnoreCase(httpRequest.getMethod())
                    && isJsonContentType(httpRequest.getContentType())) {
                forwarded = new BodyReaderHttpServletRequestWrapper(httpRequest);
            }
        }
        chain.doFilter(forwarded, response);
    }

    /**
     * Returns true when {@code contentType}'s media type (the part before any ';') is
     * {@code application/json}. Returns false for {@code null}.
     */
    private static boolean isJsonContentType(String contentType) {
        if (contentType == null) {
            return false;
        }
        int semicolon = contentType.indexOf(';');
        String mediaType = (semicolon >= 0 ? contentType.substring(0, semicolon) : contentType).trim();
        return JSON_MEDIA_TYPE.equalsIgnoreCase(mediaType);
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No configuration is read.
    }
}


在 web.xml 中配置该 Filter(注意:Filter 的 init() 没有读取任何参数,下面的 encoding init-param 实际上不起作用):


	<!-- Registers HttpServletRequestReplacedFilter, which swaps qualifying POST requests
	     (the JSON Content-Type check in its doFilter) for a body-caching wrapper. -->
	<filter>
		<filter-name>HttpServletRequestReplacedFilter</filter-name>
		<filter-class>com.microdata.core.filter.HttpServletRequestReplacedFilter</filter-class>
		<!-- NOTE(review): the filter's init() does not read its FilterConfig,
		     so this "encoding" parameter currently has no effect. -->
		<init-param>
			<param-name>encoding</param-name>
			<param-value>utf-8</param-value>
		</init-param>
	</filter>
	<!-- Applied to every request path. -->
	<filter-mapping>
		<filter-name>HttpServletRequestReplacedFilter</filter-name>
		<url-pattern>/*</url-pattern>
	</filter-mapping>


Encodes 工具类(其中用到的 Exceptions 工具类原文未给出):

package com.microdata.core.util;

import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.net.URLEncoder;

import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.lang3.StringEscapeUtils;

/**
 * 封装各种格式的编码解码工具类.
 * 
 * 1.Commons-Codec的 hex/base64 编码
 * 2.自制的base62 编码
 * 3.Commons-Lang的xml/html escape
 * 4.JDK提供的URLEncoder
 * 
 */
/**
 * Assorted encoding and decoding helpers:
 *
 * <ul>
 *   <li>Hex and Base64 encoding backed by Commons-Codec.</li>
 *   <li>A home-grown "Base62" mapping (one output character per input byte).</li>
 *   <li>HTML/XML escaping backed by Commons-Lang.</li>
 *   <li>URL encoding/decoding backed by the JDK's URLEncoder/URLDecoder.</li>
 * </ul>
 */
public class Encodes {

    /** Charset name used by {@link #urlEncode(String)} and {@link #urlDecode(String)}. */
    private static final String DEFAULT_URL_ENCODING = "UTF-8";

    /** Output alphabet for {@link #encodeBase62(byte[])}: digits, upper-case, then lower-case letters. */
    private static final char[] BASE62 =
            "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz".toCharArray();

    /**
     * Encodes bytes as a hexadecimal string via Commons-Codec {@code Hex}.
     *
     * @param input bytes to encode
     * @return the hex text
     */
    public static String encodeHex(byte[] input) {
        final String hex = Hex.encodeHexString(input);
        return hex;
    }

    /**
     * Decodes a hexadecimal string back into bytes.
     *
     * @param input hex text
     * @return the decoded bytes
     * @throws RuntimeException (via {@code Exceptions.unchecked}) wrapping the codec's
     *         {@code DecoderException} when {@code input} is not valid hex
     */
    public static byte[] decodeHex(String input) {
        final char[] digits = input.toCharArray();
        try {
            return Hex.decodeHex(digits);
        } catch (DecoderException cause) {
            throw Exceptions.unchecked(cause);
        }
    }

    /**
     * Encodes bytes as standard Base64.
     *
     * @param input bytes to encode
     * @return the Base64 text
     */
    public static String encodeBase64(byte[] input) {
        final String encoded = Base64.encodeBase64String(input);
        return encoded;
    }

    /**
     * Encodes bytes as URL-safe Base64. '+' and '/' are replaced by '-' and '_' (see RFC 3548).
     *
     * @param input bytes to encode
     * @return the URL-safe Base64 text
     */
    public static String encodeUrlSafeBase64(byte[] input) {
        final String encoded = Base64.encodeBase64URLSafeString(input);
        return encoded;
    }

    /**
     * Decodes Base64 text into bytes.
     *
     * @param input Base64 text
     * @return the decoded bytes
     */
    public static byte[] decodeBase64(String input) {
        final byte[] decoded = Base64.decodeBase64(input);
        return decoded;
    }

    /**
     * Maps each byte to one character of {@link #BASE62} using {@code (byte & 0xFF) % 62}.
     *
     * <p>256 byte values are folded onto 62 characters, so the result cannot be decoded back
     * into the original bytes.
     *
     * @param input bytes to map
     * @return a string with exactly {@code input.length} characters
     */
    public static String encodeBase62(byte[] input) {
        final StringBuilder out = new StringBuilder(input.length);
        for (byte b : input) {
            final int unsigned = b & 0xFF;
            out.append(BASE62[unsigned % BASE62.length]);
        }
        return out.toString();
    }

    /** Escapes HTML using Commons-Lang {@code escapeHtml4}. */
    public static String escapeHtml(String html) {
        final String escaped = StringEscapeUtils.escapeHtml4(html);
        return escaped;
    }

    /** Reverses HTML escaping using Commons-Lang {@code unescapeHtml4}. */
    public static String unescapeHtml(String htmlEscaped) {
        final String plain = StringEscapeUtils.unescapeHtml4(htmlEscaped);
        return plain;
    }

    /** Escapes XML using Commons-Lang {@code escapeXml}. */
    public static String escapeXml(String xml) {
        final String escaped = StringEscapeUtils.escapeXml(xml);
        return escaped;
    }

    /** Reverses XML escaping using Commons-Lang {@code unescapeXml}. */
    public static String unescapeXml(String xmlEscaped) {
        final String plain = StringEscapeUtils.unescapeXml(xmlEscaped);
        return plain;
    }

    /**
     * URL-encodes {@code part} with {@link URLEncoder} using UTF-8.
     *
     * @param part text to encode
     * @return the encoded text
     */
    public static String urlEncode(String part) {
        final String encoded;
        try {
            encoded = URLEncoder.encode(part, DEFAULT_URL_ENCODING);
        } catch (UnsupportedEncodingException cause) {
            throw Exceptions.unchecked(cause);
        }
        return encoded;
    }

    /**
     * URL-decodes {@code part} with {@link URLDecoder} using UTF-8.
     *
     * @param part text to decode
     * @return the decoded text
     */
    public static String urlDecode(String part) {
        final String decoded;
        try {
            decoded = URLDecoder.decode(part, DEFAULT_URL_ENCODING);
        } catch (UnsupportedEncodingException cause) {
            throw Exceptions.unchecked(cause);
        }
        return decoded;
    }
}


  • 5
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
如果您的 HttpServletRequest 拦截器读取一次请求数据之后再次读取时读不到数据,可能是因为 HttpServletRequest 对象的输入流只能读取一次,一旦读取过,它就不再可用。为了解决这个问题,您可以先将 HttpServletRequest 对象的请求数据读取到一个字节数组中,再把字节数组包装进一个新的 HttpServletRequest 对象,然后用这个新对象进行后续处理。以下是一个示例:

```java
public class MyInterceptor implements HandlerInterceptor {
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // 读取请求数据到字节数组
        byte[] requestBody = IOUtils.toByteArray(request.getInputStream());
        // 创建新的HttpServletRequest对象,并将字节数组包装在里面
        HttpServletRequestWrapper requestWrapper = new HttpServletRequestWrapper(request) {
            @Override
            public ServletInputStream getInputStream() throws IOException {
                return new ServletInputStreamWrapper(requestBody);
            }
            @Override
            public int getContentLength() {
                return requestBody.length;
            }
            @Override
            public long getContentLengthLong() {
                return requestBody.length;
            }
        };
        // 将新的HttpServletRequest对象用于后续处理
        // ...
        return true;
    }
}
```

在这个示例中,我们先用 IOUtils.toByteArray() 方法将 HttpServletRequest 对象的输入流读取到一个字节数组中,然后创建一个新的 HttpServletRequestWrapper 对象并将字节数组包装在里面,最后将这个新的 HttpServletRequestWrapper 对象用于后续处理。这样,即使 HttpServletRequest 对象的输入流只能读取一次,您也可以在拦截器中多次读取请求数据。

注意:有两点需要留意。

- HandlerInterceptor 的 preHandle() 无法把包装后的 request 传递给后续的 Controller。Controller 拿到的仍是原始 request,而它的输入流此时已被读空。
- 示例中的 ServletInputStreamWrapper 也不是 Servlet API 自带的类,需要自行实现。

要让 Controller 也能读到请求体,应像本文前面那样在 Filter 中完成包装,并调用 chain.doFilter(wrapper, response)。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值