原理很简单:在请求进入业务代码之前加一层过滤器,根据你的业务需要把请求参数中的非法字符过滤(屏蔽)掉;
注意事项:如果接口需要接收XML格式的数据,则不要过滤<和>符号,可以在过滤时先判断请求的数据类型(Content-Type)再决定是否过滤~另外,仅删除<>并不能完全防御XSS,数据输出到页面时仍应做HTML转义。
1)首先创建一个继承HttpServletRequestWrapper的包装类,重写以下方法,然后在过滤器中使用它(这里把包装类作为静态内部类写在了过滤器里):
public class DefaultFilter implements Filter {
private final static Logger LOGGER = LoggerFactory.getLogger(DefaultFilter.class);
//这里在是注册filter时传入参数
private String xss;
@Override
public void init(FilterConfig filterConfig) throws ServletException {
xss = filterConfig.getInitParameter("xss");
}
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
long start = System.currentTimeMillis();
HttpServletRequest request = (HttpServletRequest) servletRequest;
WebHttpUtils.logHttpRequest(request);
//判断是否使用xss过滤,或者你可以传入其他有用信息
if (xss != null) {
filterChain.doFilter(new CheckXssRequestWrapper(request), servletResponse);
} else {
filterChain.doFilter(request, servletResponse);
}
LOGGER.info("请求地址:{}的请求时间为:{}秒", ((HttpServletRequest) servletRequest).getRequestURI(),
new BigDecimal(System.currentTimeMillis() - start).divide(new BigDecimal(1000)).setScale(2, RoundingMode.HALF_UP));
}
@Override
public void destroy() {
}
private static class CheckXssRequestWrapper extends HttpServletRequestWrapper {
//拿到所有表单请求的参数及对应值数组
Map<String, String[]> parameterMap;
//拿到存在body体中的输入流,比如接收json
ServletInputStream inputStream;
//构造方法
public CheckXssRequestWrapper(HttpServletRequest request) throws IOException {
super(request);
this.parameterMap = request.getParameterMap();
this.inputStream = request.getInputStream();
this.handlerForm();
this.handlerBody();
}
//获取流
@Override
public ServletInputStream getInputStream() throws IOException {
return this.inputStream;
}
@Override
public String getParameter(String name) {
return this.parameterMap.get(name) == null ? null : this.parameterMap.get(name)[0];
}
//在项目中,比如springboot的控制层接收对象表单会走这个方法
@Override
public String[] getParameterValues(String name) {
return this.parameterMap.get(name);
}
//处理表单数据
private void handlerForm() {
if (parameterMap != null && parameterMap.size() > 0) {
parameterMap.forEach((k, v) -> {
if (v.length > 0) {
for (int i = 0, length = v.length; i < length; i++) {
v[i] = filterParameter(v[i]);
}
}
});
}
}
//处理请求体中的数据
private void handlerBody() throws IOException {
if (inputStream != null) {
StringBuffer str = new StringBuffer();
byte[] b = new byte[1024];
int i = 0;
while ((i = inputStream.read(b)) != -1) {
str.append(new String(b, 0, i));
}
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(filterParameter(str.toString()).getBytes());
inputStream = new ServletInputStream() {
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setReadListener(ReadListener readListener) {
}
@Override
public int read() throws IOException {
return byteArrayInputStream.read();
}
};
}
}
//通用方法,加入你想过滤的字符串
private static final String[] dangerCharacters = {"<", ">"};
private String filterParameter(String parameter) {
if (StringUtils.isEmpty(parameter)) return parameter;
StringBuffer sb = new StringBuffer(parameter);
for (String s : dangerCharacters) {
while (sb.indexOf(s) > -1) {
sb.deleteCharAt(sb.indexOf(s));
}
}
return sb.toString();
}
}
}
2)注册过滤器(Filter):
/**
 * Registers {@code DefaultFilter} for all URLs under the name {@code defaultFilter}.
 * It also sets the {@code xss} init parameter; {@code DefaultFilter} enables its XSS request
 * wrapper whenever that parameter is non-null.
 *
 * @return the filter registration picked up by Spring Boot
 */
@Bean
public FilterRegistrationBean filterRegistrationBean() {
    FilterRegistrationBean bean = new FilterRegistrationBean();
    bean.setFilter(new DefaultFilter());
    bean.setUrlPatterns(Arrays.asList("/*"));
    bean.setName("defaultFilter");
    // addInitParameter replaces double-brace initialization. That idiom creates an anonymous
    // raw HashMap subclass, which also holds a reference to the enclosing configuration instance.
    bean.addInitParameter("xss", "yes");
    return bean;
}