思路:创建两个包装类(容器)分别装载 Request 和 Response,再编写一个过滤器(Filter)拦截请求,将请求信息(Info)装载到容器中。
容器1:
import com.baomidou.mybatisplus.core.toolkit.ObjectUtils;
import com.longshine.luxicrmboot.commons.utils.ApplicationUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.util.StreamUtils;
import org.springframework.web.util.HtmlUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Map;
import java.util.Objects;
import java.util.Vector;
/**
* Request包装类
* <p>
* 1.预防xss攻击
* 2.拓展requestbody无限获取(HttpServletRequestWrapper只能获取一次)
* </p>
*
* @author Caratacus
*/
@Slf4j
public class RequestWrapper extends HttpServletRequestWrapper {
/**
* 存储requestBody byte[]
*/
private final byte[] body;
//这个容器是为了让@RequestParam能够绑定表单数据.
private final Map<String, String[]> parameterMap;
public RequestWrapper(HttpServletRequest request) {
super(request);
parameterMap = request.getParameterMap();
byte[] body = new byte[0];
try {
body = StreamUtils.copyToByteArray(request.getInputStream());
} catch (IOException e) {
log.error("Error: Get RequestBody byte[] fail," + e);
}
this.body = body;
}
@Override
public BufferedReader getReader() {
ServletInputStream inputStream = getInputStream();
return Objects.isNull(inputStream) ? null : new BufferedReader(new InputStreamReader(inputStream));
}
@Override
public ServletInputStream getInputStream() {
if (ObjectUtils.isEmpty(body)) {
return null;
}
final ByteArrayInputStream bais = new ByteArrayInputStream(body);
return new ServletInputStream() {
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return false;
}
@Override
@SuppressWarnings("EmptyMethod")
public void setReadListener(ReadListener readListener) {
}
@Override
public int read() {
return bais.read();
}
};
}
@Override
public String[] getParameterValues(String name) {
String[] values = super.getParameterValues(name);
if (values == null) {
return null;
}
int count = values.length;
String[] encodedValues = new String[count];
for (int i = 0; i < count; i++) {
encodedValues[i] = htmlEscape(values[i]);
}
return encodedValues;
}
@Override
public String getParameter(String name) {
String value = super.getParameter(name);
if (value == null) {
return null;
}
return htmlEscape(value);
}
@Override
public Object getAttribute(String name) {
Object value = super.getAttribute(name);
if (value instanceof String) {
htmlEscape((String) value);
}
return value;
}
@Override
public String getHeader(String name) {
String value = super.getHeader(name);
if (value == null) {
return null;
}
return htmlEscape(value);
}
@Override
public String getQueryString() {
String value = super.getQueryString();
if (value == null) {
return null;
}
return htmlEscape(value);
}
@Override
public Enumeration<String> getParameterNames() {
Vector<String> vector = new Vector<String>(parameterMap.keySet());
return vector.elements();
}
/**
* 使用spring HtmlUtils 转义html标签达到预防xss攻击效果
*
* @param str
* @see org.springframework.web.util.HtmlUtils#htmlEscape
*/
protected String htmlEscape(String str) {
return HtmlUtils.htmlEscape(str);
}
}
容器2:
import com.alibaba.fastjson.JSON;
import com.google.common.base.Throwables;
import com.longshine.luxicrmboot.commons.msg.AjaxResult;
import com.longshine.luxicrmboot.commons.msg.ErrorCode;
import io.swagger.annotations.ApiResponses;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.MimeTypeUtils;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
/**
* response包装类
*
* @author Caratacus
*/
@Slf4j
public class ResponseWrapper extends HttpServletResponseWrapper {
private ErrorCode errorcode;
public ResponseWrapper(HttpServletResponse response) {
super(response);
}
public ResponseWrapper(HttpServletResponse response, ErrorCode errorcode) {
super(response);
setErrorCode(errorcode);
}
/**
* 获取ErrorCode
*
* @return
*/
public ErrorCode getErrorCode() {
return errorcode;
}
/**
* 设置ErrorCode
*
* @param errorCode
*/
public void setErrorCode(ErrorCode errorCode) {
if (Objects.nonNull(errorCode)) {
this.errorcode = errorCode;
super.setStatus(this.errorcode.getHttpCode());
}
}
/**
* 向外输出错误信息
*
* @param e
* @throws IOException
*/
public void writerErrorMsg(Exception e) {
if (Objects.isNull(errorcode)) {
log.warn("Warn: ErrorCodeEnum cannot be null, Skip the implementation of the method.");
return;
}
printWriterApiResponses(AjaxResult.failure(this.getErrorCode(), e));
}
/**
* 设置ApiErrorMsg
*/
public void writerErrorMsg() {
writerErrorMsg(null);
}
/**
* 向外输出AjaxResult
*
* @param ajaxResult
*/
public void printWriterApiResponses(AjaxResult ajaxResult) {
writeValueAsJson(ajaxResult);
}
/**
* 向外输出json对象
*
* @param obj
*/
public void writeValueAsJson(Object obj) {
if (super.isCommitted()) {
log.warn("Warn: Response isCommitted, Skip the implementation of the method.");
return;
}
super.setContentType(MimeTypeUtils.APPLICATION_JSON_VALUE);
super.setCharacterEncoding(StandardCharsets.UTF_8.name());
try (PrintWriter writer = super.getWriter()) {
writer.print(JSON.toJSONString(obj));
writer.flush();
} catch (IOException e) {
log.warn("Error: Response printJson faild, stackTrace: {}", Throwables.getStackTraceAsString(e));
}
}
}
过滤器:
import com.longshine.luxicrmboot.commons.wrapper.RequestWrapper;
import org.springframework.stereotype.Component;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
/**
 * Filter that wraps each HTTP request in a {@link RequestWrapper} before passing it down the
 * chain, so the request body can be read more than once by later components.
 * <p>
 * Only the request is wrapped; the response is forwarded unchanged.
 * </p>
 *
 * @author Caratacus
 */
@Component
@WebFilter(filterName = "crownFilter", urlPatterns = "/*")
public class MemoryReqResFilter implements Filter {

    /**
     * No initialization is required.
     *
     * @param config filter configuration, unused
     */
    @Override
    @SuppressWarnings("EmptyMethod")
    public void init(FilterConfig config) {
    }

    /**
     * Wraps the request in a {@link RequestWrapper} and continues the chain.
     * Requests that are not HTTP requests, or that are already a {@link RequestWrapper},
     * are forwarded as they are.
     *
     * @param request the incoming request
     * @param res     the response, forwarded unchanged
     * @param chain   the remaining filter chain
     * @throws ServletException if a later filter or the target resource fails
     * @throws IOException      if an I/O error occurs further down the chain
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse res,
                         FilterChain chain) throws ServletException, IOException {
        ServletRequest effective = request;
        // Wrapping again would cache the body twice and escape values twice
        // (e.g. if this filter ends up registered more than once).
        if (request instanceof HttpServletRequest && !(request instanceof RequestWrapper)) {
            effective = new RequestWrapper((HttpServletRequest) request);
        }
        chain.doFilter(effective, res);
    }

    /**
     * No resources to release.
     */
    @Override
    @SuppressWarnings("EmptyMethod")
    public void destroy() {
    }
}