摘要：大家知道，Spring MVC 中的 @RequestBody 是通过读取请求输入流来获取数据的；如果在此之前流已经被读取过，后面就读不到数据了。
我的 Filter 需要校验请求参数（包括 Request Payload 中的数据）是否含有非法符号（防止 SQL 注入），因此必须先在 Filter 中读取一次请求体，同时还要保证后续的 @RequestBody 仍然能够再次读取它。
package com.ks.tow.common.filter;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.Reader;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.io.IOUtils;
import org.springframework.web.util.ContentCachingRequestWrapper;
import com.alibaba.fastjson.JSONObject;
import com.ks.shop.util.secure.CheckSQLInjectionUtil;
import com.ks.tow.common.enums.HttpStatus;
import com.ks.tow.util.StringUtil;
/**
 * Servlet filter that rejects requests whose parameters or body contain SQL-injection-like
 * text, as judged by {@link CheckSQLInjectionUtil#validate(String)} (anti SQL-injection filter).
 *
 * <p>Request paths listed in {@code excludes} bypass the check. An entry ending in {@code *} is a
 * prefix match; any entry is also compared for exact equality. Paths are compared relative to
 * the context path, without the leading {@code '/'} and without trailing {@code '/'}s.
 *
 * <p>Bodies of type {@code application/json} and {@code text/plain} are buffered in a
 * {@link ReaderReuseHttpServletRequestWrapper}. This lets the filter inspect them while
 * downstream code (e.g. {@code @RequestBody}) can still read them. Rejected requests receive a
 * JSON error body instead of being passed down the chain.
 *
 * @author LIU
 */
public class CheckSQLInjectionFilter implements Filter {

    /** Request paths (relative to the context path) that skip the injection check. */
    private List<String> excludes = new ArrayList<>();

    public void setExcludes(List<String> excludes) {
        this.excludes = excludes;
    }

    public List<String> getExcludes() {
        return excludes;
    }

    /**
     * Appends the entries of the comma-separated {@code excludes} init parameter to the
     * exclusion list. Entries are trimmed and empty entries are skipped, so
     * {@code "a, b"} yields {@code "a"} and {@code "b"}.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        String excludes = filterConfig.getInitParameter("excludes");
        if (StringUtil.isNotBlank(excludes)) {
            for (String url : excludes.split(",")) {
                String trimmed = url.trim();
                if (!trimmed.isEmpty()) {
                    this.excludes.add(trimmed);
                }
            }
        }
    }

    /**
     * Passes excluded paths straight through. Otherwise, validates every request parameter
     * value and (for JSON / text bodies) every body value. The request is forwarded only if all
     * values pass; otherwise a JSON rejection is written to the response.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse resp = (HttpServletResponse) response;

        if (isExcluded(relativePath(req))) {
            chain.doFilter(req, resp);
            return;
        }

        // All values that must pass the check. A list rather than a name-keyed map, so a
        // request parameter can never hide a body field that has the same name.
        List<Object> values = new ArrayList<>();

        // Read parameters BEFORE touching the body: containers only parse form-urlencoded POST
        // bodies into parameters if the input stream has not been consumed yet. All values of
        // multi-valued parameters are checked, not just the first one.
        Enumeration<String> names = req.getParameterNames();
        while (names.hasMoreElements()) {
            values.add(req.getParameterValues(names.nextElement()));
        }

        // Only buffer bodies we actually inspect. Every other content type (forms, multipart
        // uploads, ...) goes downstream untouched, with its stream unread.
        HttpServletRequest forwarded = req;
        String type = req.getContentType();
        if (type != null && (type.startsWith("application/json") || type.startsWith("text/plain"))) {
            ReaderReuseHttpServletRequestWrapper wrapper = new ReaderReuseHttpServletRequestWrapper(req);
            String payload = IOUtils.toString(wrapper.getReader());
            collectBodyValues(type, payload, values);
            forwarded = wrapper;
        }

        for (Object value : values) {
            if (!isSafe(value)) {
                writeRejection(resp);
                return;
            }
        }
        chain.doFilter(forwarded, resp);
    }

    @Override
    public void destroy() {
    }

    /** Returns the request URI relative to the context path, without leading or trailing '/'. */
    private static String relativePath(HttpServletRequest req) {
        // Substring by the context-path length only, so a URI equal to the context path
        // (no trailing '/') yields "" instead of throwing.
        String path = req.getRequestURI().substring(req.getContextPath().length());
        if (path.startsWith("/")) {
            path = path.substring(1);
        }
        while (path.endsWith("/")) {
            path = path.substring(0, path.length() - 1);
        }
        return path;
    }

    /** Returns true when {@code path} matches an exclusion (prefix for '*' entries, or exact). */
    private boolean isExcluded(String path) {
        for (String entry : excludes) {
            boolean prefixMatch = entry.endsWith("*")
                    && path.startsWith(entry.substring(0, entry.length() - 1));
            if (prefixMatch || entry.equals(path)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds the checkable values of a request body to {@code out}.
     * For a JSON body, adds the values of the top-level object (nested structures are walked
     * later by {@link #isSafe(Object)}). For a text/plain body, splits on '&' and adds the
     * value part of each {@code key=value} pair; fragments without '=' are added whole.
     */
    private static void collectBodyValues(String type, String payload, List<Object> out) {
        if (type.startsWith("application/json")) {
            // NOTE(review): parseObject presumably throws on malformed JSON or a top-level
            // array, which surfaces as a server error — confirm this is the desired outcome.
            JSONObject jsonObject = JSONObject.parseObject(payload);
            if (jsonObject != null) {
                out.addAll(jsonObject.values());
            }
        } else if (type.startsWith("text/plain")) {
            for (String pair : payload.split("&")) {
                // Limit 2 keeps any '=' that belongs to the value itself.
                String[] kv = pair.split("=", 2);
                out.add(kv.length == 2 ? kv[1] : pair);
            }
        }
    }

    /**
     * Returns true when {@code value} contains no text rejected by CheckSQLInjectionUtil.
     * Strings are checked directly. Maps (JSON objects), iterables (JSON arrays) and object
     * arrays (multi-valued parameters) are checked element by element, so nested data cannot
     * bypass the filter. Any other value (numbers, booleans, null) is considered safe.
     */
    private static boolean isSafe(Object value) {
        if (value instanceof String) {
            return CheckSQLInjectionUtil.validate((String) value);
        }
        if (value instanceof Map) {
            return isSafe(((Map<?, ?>) value).values());
        }
        if (value instanceof Iterable) {
            for (Object element : (Iterable<?>) value) {
                if (!isSafe(element)) {
                    return false;
                }
            }
            return true;
        }
        if (value instanceof Object[]) {
            for (Object element : (Object[]) value) {
                if (!isSafe(element)) {
                    return false;
                }
            }
            return true;
        }
        return true;
    }

    /** Writes the JSON "security violation" response for a rejected request. */
    private static void writeRejection(HttpServletResponse resp) throws IOException {
        resp.setContentType("application/json;charset=UTF-8");
        PrintWriter writer = resp.getWriter();
        writer.write("{\"success\":false,\"msg\":\""+HttpStatus.SECURITY.getName()+"\",\"code\":"+HttpStatus.SECURITY.getCode()+"}");
        writer.flush();
    }

    /**
     * Request wrapper that buffers the body once so that {@link #getInputStream()} and
     * {@link #getReader()} can each be called any number of times.
     *
     * <p>The servlet API documents that a request body can normally be read only once, because
     * it is a stream. Keeping a copy of the bytes makes repeated reads possible. The ORIGINAL
     * bytes are kept: decoding and re-encoding here would corrupt bodies (e.g. Chinese text)
     * whose declared charset differs from the one later used to decode them.
     *
     * <p>NOTE(review): the whole body is held in memory with no size limit. Also, on Servlet
     * 3.1+ {@code ServletInputStream} additionally declares {@code isFinished}, {@code isReady}
     * and {@code setReadListener}, which the anonymous stream below would have to implement.
     *
     * @author LIU
     */
    public static class ReaderReuseHttpServletRequestWrapper extends HttpServletRequestWrapper {

        /** Charset used when the request declares none, or declares one the JVM cannot load. */
        private static final Charset DEFAULT_CHARSET = Charset.forName("UTF-8");

        /** Raw, undecoded request body. */
        private final byte[] body;

        public ReaderReuseHttpServletRequestWrapper(HttpServletRequest request)
                throws IOException {
            super(request);
            body = IOUtils.toByteArray(request.getInputStream());
        }

        /** Decodes the buffered body with the request's declared charset (UTF-8 fallback). */
        @Override
        public BufferedReader getReader() throws IOException {
            return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
        }

        /** Returns a fresh stream over the buffered body bytes. */
        @Override
        public ServletInputStream getInputStream() throws IOException {
            final ByteArrayInputStream bais = new ByteArrayInputStream(body);
            return new ServletInputStream() {
                @Override
                public int read() throws IOException {
                    return bais.read();
                }

                // Bulk read, so that consumers reading in chunks avoid one call per byte.
                @Override
                public int read(byte[] b, int off, int len) throws IOException {
                    return bais.read(b, off, len);
                }

                @Override
                public int available() throws IOException {
                    return bais.available();
                }
            };
        }

        /** Returns the request's declared charset, or UTF-8 if it is absent or unsupported. */
        private Charset bodyCharset() {
            String encoding = getCharacterEncoding();
            if (encoding != null) {
                try {
                    return Charset.forName(encoding);
                } catch (IllegalArgumentException ignored) {
                    // Illegal or unsupported charset name in the request: use the default.
                }
            }
            return DEFAULT_CHARSET;
        }
    }
}
请注意这里的编码：缓存请求体和再次读取请求体时必须使用同一种字符集，最好统一为 UTF-8，否则获取到的中文会是乱码。我自己也习惯使用 UTF-8 编码。
这样基本就完成了。
以下是校验 SQL 注入的关键代码：
/**
 * Blacklist-based detector for SQL-injection-like input.
 *
 * <p>A value is rejected when it contains any of the following:
 * <ul>
 *   <li>a single quote;</li>
 *   <li>a {@code --} line comment;</li>
 *   <li>a {@code /*}-style block comment;</li>
 *   <li>one of the listed SQL keywords as a whole word (case-insensitive).</li>
 * </ul>
 *
 * <p>This is a blunt heuristic. It also rejects ordinary text such as {@code "O'Brien"} or
 * {@code "rock and roll"}, and it is no substitute for parameterized queries.
 */
public final class CheckSQLInjectionUtil {

    // "truncate" was previously misspelled "trancate", so TRUNCATE was never detected;
    // the duplicate "into" entry has been removed.
    private static final String SQL_REG = "(?:')|(?:--)|(/\\*(?:.|[\\n\\r])*?\\*/)|"
            + "(\\b(select|update|and|or|delete|insert|truncate|char|into|substr|"
            + "ascii|declare|exec|count|master|drop|execute)\\b)";

    /** Compiled once; {@link Pattern} is immutable and safe to share between threads. */
    private static final Pattern PATTERN = Pattern.compile(SQL_REG, Pattern.CASE_INSENSITIVE);

    /**
     * Checks a single value for SQL injection.
     *
     * @param str the value to check; {@code null} contains nothing to inject and is accepted
     * @return {@code true} if the value is acceptable, {@code false} if it matches the blacklist
     */
    public static boolean validate(String str) {
        return str == null || !PATTERN.matcher(str).find();
    }

    /**
     * Checks several values for SQL injection.
     *
     * @param strs the values to check; a {@code null} array or {@code null} elements are accepted
     * @return {@code true} only if every value is acceptable
     */
    public static boolean validate(String[] strs) {
        if (strs == null) {
            return true;
        }
        for (String str : strs) {
            if (!validate(str)) {
                return false;
            }
        }
        return true;
    }
}