应用场景 :
因项目中接口请求时, 需要对请求参数进行签名验证。 当请求参数的body中有基本类型时(例: int, long, boolean等),因为基本类型如果没传值,序列化的时候会有默认值的问题, 最后导致实际接口调用生成的签名和项目中进行校验的签名不匹配。如果直接从request中获取请求参数body, 会出现request请求流重复读取异常,因此需要实现HttpServletRequestWrapper 重写getInputStream()和getReader()方法,将请求参数body复制到自己requestWrapper中, 后续只操作自己的requestWrapper
代码实现
启动类加上@ServletComponentScan 注解,开启Servlet、Filter、Listener可以直接通过@WebServlet、@WebFilter、@WebListener注解自动注册
//开启Servlet、Filter、Listener可以直接通过@WebServlet、@WebFilter、@WebListener注解自动注册
@ServletComponentScan
@SpringBootApplication(scanBasePackages = {"xxx.xxx.xx"})
public class StartApplication {
public static void main(String[] args) {
SpringApplication.run(StartApplication.class, args);
}
}
自定义filter, 必须将自定义的wrapper通过过滤器传下去, 不传不会调用重写后的getInputStream()和getReader()方法
@Component
@WebFilter(filterName = "RewriteRequestFilter", urlPatterns = "/*")
@Order(1)
public class RewriteRequestFilter implements Filter {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
//文件上传类型 不需要处理,否则会报java.nio.charset.MalformedInputException: Input length = 1异常
if (Objects.isNull(request) || Optional.ofNullable(request.getContentType()).orElse(StringUtils.EMPTY).startsWith("multipart/")) {
chain.doFilter(request, response);
return;
}
//自定义wrapper 处理流,必须在过滤器中处理,然后通过FilterChain传下去, 否则重写后的getInputStream()方法不会被调用
MyHttpServletRequestWrapper requestWrapper = new MyHttpServletRequestWrapper((HttpServletRequest)request);
chain.doFilter(requestWrapper,response);
}
}
自定义wrapper复制请求流, 如果不重写会报 java.lang.IllegalStateException: getInputStream() has already been called for this request 异常, 原因是request请求流不能重复读取。
@Slf4j
@Getter
public class MyHttpServletRequestWrapper extends HttpServletRequestWrapper {
/** 复制请求body */
private final String body;
public MyHttpServletRequestWrapper (HttpServletRequest request) {
super(request);
try {
//设置编码格式, 防止中文乱码
request.setCharacterEncoding("UTF-8");
//将请求中的流取出来放到body里,后面都只操作body就行
this.body = RequestReadUtils.read(request);
} catch (Exception e) {
log.error("MyHttpServletRequestWrapper exception", e);
throw new RuntimeException("MyHttpServletRequestWrapper 拦截器异常");
}
}
@Override
public ServletInputStream getInputStream() {
//返回body的流信息即可
try(final ByteArrayInputStream bais = new ByteArrayInputStream(body.getBytes())){
return getServletInputStream(bais);
}catch(IOException e){
log.error("MyHttpServletRequestWrapper.getInputStream() exception", e);
throw new RuntimeException("MyHttpServletRequestWrapper 获取input流异常");
}
}
@Override
public BufferedReader getReader(){
return new BufferedReader(new InputStreamReader(this.getInputStream()));
}
/**
* 重写getInputStream流
* @param bais
* @return
*/
private static ServletInputStream getServletInputStream(ByteArrayInputStream bais) {
return new ServletInputStream() {
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setReadListener(ReadListener readListener) {
}
@Override
public int read() {
return bais.read();
}
};
}
}
读取请求流工具类
@Slf4j
public class RequestReadUtils {
/**
* 读取请求流
* @param request
* @return
* @throws UnsupportedEncodingException
*/
public static String read(HttpServletRequest request){
try(BufferedReader reader = request.getReader()){
StringBuilder sb = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
sb.append(line);
}
return sb.toString();
}catch (Exception e){
log.error("MyHttpServletRequestWrapper.RequestReadUtils.readexception", e);
throw new RuntimeException("MyHttpServletRequestWrapper.RequestReadUtils.read 获取请求流异常");
}
}
}
自定义interceptor拦截器, 判断如果是自己的wrapper, 从wrapper中获取请求参数。只能在拦截器中获取请求参数,
1.如果在自定义的Filter中获取请求参数, restful风格的请求参数无法获取。
2.如果在AOP中获取请求参数, 获取的是序列化后的请求参数(基本类型默认值也会被获取)
@Slf4j
public class CommonInterceptor extends BaseInterceptor {
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
//如果是自定义wrapper, 从自定义wrapper中获取请求参数, 必须在interceptor拦截器中处理, 否则restful风格的请求参数获取不到。
if(request instanceof MyHttpServletRequestWrapper){
Map<String, Object> requestParam = getRequestParam((MyHttpServletRequestWrapper)request);
//放到ThreadLocal中, 这里可以根据自己的项目业务处理
CommonData.set("param", requestParam);
}
return true;
}
@Override
public void afterCompletion(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse,
Object o, Exception e) throws Exception {
//清空ThreadLocal, 防止内存泄漏
clearAllData();
}
private void clearAllData() {
CommonData.clearAll();
}
/**
* 从request获取参数
* @param request
* @return
* @throws IOException
*/
private Map<String, Object> getRequestParam(CallHttpServletRequestWrapper request){
Map<String, Object> paramMap = Maps.newHashMap();
//获取使用@RequestParam注解的参数
Map<String, String[]> parameterMap = request.getParameterMap();
if(!CollectionUtils.isEmpty(parameterMap)){
parameterMap.forEach((k,v)->{
if(Objects.nonNull(v) && v.length > 0){
paramMap.put(k, v[0]);
}
});
}
//获取restful请求参数,必须在interceptor拦截其中才能这样获取到restful参数
Object attribute = request.getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE);
if(Objects.nonNull(attribute)){
Map<String, Object> attributeMap = (Map<String, Object>)attribute;
if(!CollectionUtils.isEmpty(attributeMap)){
paramMap.putAll(attributeMap);
}
}
//从自定义wrapper中, 获取body体参数
String bodyString = request.getBody();
if(StringUtils.isBlank(bodyString)){
return paramMap;
}
//解析body参数
Map<String, Object> bodyMap = parseRequestMap(bodyString);
if(CollectionUtils.isEmpty(bodyMap)){
return paramMap;
}
paramMap.putAll(bodyMap);
return paramMap;
}
/**
* 解析body请求参数
* @param bodyString
* @return
*/
private Map<String, Object> parseRequestMap(String bodyString) {
Map<String, Object> paramMap = Maps.newHashMap();
boolean validObject = JSONObject.isValidObject(bodyString);
//解析@ReqeustBody注解参数
if(validObject){
JSONObject jsonObject = JSONObject.parseObject(bodyString);
paramMap.putAll(jsonObject);
}else{
//解析url拼接参数 例 a=123&b=456, 没有加@RequestBoyd注解的post请求
String[] param = bodyString.split(SpecialConstant.AND);
if(param.length == 0){
return paramMap;
}
Stream.of(param).forEach(e->{
String[] split = e.split(SpecialConstant.EQ);
if(split.length == 0){
return;
}
paramMap.put(split[0], split[1]);
});
}
return paramMap;
}
}
自定义ThreadLocal工具类
@Slf4j
@Getter
@Setter
public class CommonData {
private static ThreadLocal<Map<String, Object>> threadLocal = new ThreadLocal<>();
/**
* 添加数据
* @param key
* @param value
*/
public static void set(String key, Object value) {
if (threadLocal.get() == null) {
Map<String, Object> map = new HashMap<>();
threadLocal.set(map);
}
threadLocal.get().put(key, value);
}
/**
* 清除数据
*/
public static void clearAll() {
threadLocal.set(null);
}
public static Map<String, Object> getSignParam() {
Object o = threadLocal.get().get("param");
if (Objects.isNull(o)) {
log.info("CommonData.getSignParam is null");
return null;
}
return (Map<String, Object>) o;
}
}
至此通过自定义wrapper重复读取request请求流的方式完成, 也不会再报 java.lang.IllegalStateException: getInputStream() has already been called for this request异常
最后感谢大家的阅读, 如问题请随时指出!!!