一、SpringBoot 针对富文本和非富文本接口添加 XSS 过滤（如果富文本字段名是唯一的——即不与任何非富文本字段同名——实际上只写一个 HttpServletRequestWrapper 就够了）
1.xss过滤器
package com.doctortech.tmc.filter;
import com.doctortech.tmc.support.xss.XssHttpServletRequestWrapper;
import com.doctortech.tmc.support.xss.XssRichTextHttpServletRequestWrapper;
import org.springframework.stereotype.Component;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
/**
 * Servlet filter that wraps incoming requests in an XSS-cleaning request wrapper.
 *
 * <p>Requests whose servlet path contains one of {@code EXCLUSION_URLS} (file uploads) are passed
 * through untouched. Requests whose path contains one of {@code RICH_TEXT_URLS} are wrapped in
 * {@link XssRichTextHttpServletRequestWrapper}, which keeps HTML markup in rich-text fields. Every
 * other request is wrapped in {@link XssHttpServletRequestWrapper}.
 */
@WebFilter
@Component
public class XssFilter implements Filter {

    /** Servlet-path fragments that bypass XSS cleaning entirely. */
    private static final String[] EXCLUSION_URLS = {"/fileUpload/upload", "/fileUpload/upload/img"};

    /** Servlet-path fragments whose requests carry rich text and use the rich-text wrapper. */
    private static final String[] RICH_TEXT_URLS = {
        "/admin/adminArticle/add", "/admin/adminArticle/update", "/admin/adminPolicy/add"
    };

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No initialisation required.
    }

    /**
     * Chooses the wrapper for this request based on its servlet path and continues the chain.
     *
     * @param servletRequest  the incoming request (assumed to be an {@link HttpServletRequest})
     * @param servletResponse the response, passed through unchanged
     * @param filterChain     the remaining filter chain
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) servletRequest;
        String path = req.getServletPath();
        // NOTE(review): contains() matches the fragment anywhere in the path, so e.g. any path that
        // merely contains "/fileUpload/upload" skips cleaning. Confirm this is intended.
        if (matchesAny(path, EXCLUSION_URLS)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        // Wrappers are only constructed for the branch that actually uses them.
        if (matchesAny(path, RICH_TEXT_URLS)) {
            filterChain.doFilter(new XssRichTextHttpServletRequestWrapper(req), servletResponse);
            return;
        }
        filterChain.doFilter(new XssHttpServletRequestWrapper(req), servletResponse);
    }

    /** Returns true when {@code path} contains any of the given fragments. */
    private static boolean matchesAny(String path, String[] fragments) {
        for (String fragment : fragments) {
            if (path.contains(fragment)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }
}
2.针对富文本接口进行过滤
package com.doctortech.tmc.support.xss;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.StringUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.util.*;
/**
 * Request wrapper for endpoints that accept rich-text (HTML) input.
 *
 * <p>Values of fields whose name is listed in {@code richTextKey} (currently only {@code content})
 * keep their markup. Only {@code <script>}/{@code <javascript>} blocks, {@code eval(...)} calls and
 * the word {@code alert} are removed from them. Every other value additionally has {@code <} and
 * {@code >} replaced with {@code &lt;} and {@code &gt;}.
 *
 * <p>JSON request bodies are parsed, cleaned field by field and re-serialized. Parameters, headers
 * and the query string are cleaned on each access.
 */
public class XssRichTextHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /** Charset used both to decode the incoming body and to encode the cleaned body. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");
    /** Shared JSON mapper; ObjectMapper is thread-safe once configured and is never reconfigured here. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    private static final String SCRIPT_TAG_REGEX =
            "<[\\s]*?script[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?script[\\s]*?>";
    private static final String JAVASCRIPT_TAG_REGEX =
            "<[\\s]*?javascript[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?javascript[\\s]*?>";
    /** Greedy: removes everything from the first "eval(" to the last ")" on a line. */
    private static final String EVAL_CALL_REGEX = "eval\\((.*)\\)";
    private static final String sqlKey =
            "and|exec|insert|select|delete|update|count|*|%|chr|mid|master|truncate|char|declare|;|or|-|+";
    /** Pipe-separated names of fields whose values are treated as rich text. */
    private static final String richTextKey = "content";
    /** Parsed from sqlKey; currently unused because cleanSqlKeyWords is deliberately a no-op. */
    private static final Set<String> notAllowedKeyWords = new HashSet<>(Arrays.asList(sqlKey.split("\\|")));
    private static final Set<String> richTextKeySet = new HashSet<>(Arrays.asList(richTextKey.split("\\|")));

    /** Cleaned request body, read lazily on the first getInputStream()/getReader() call. */
    private String cleanedBody;

    public XssRichTextHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /** Returns the cleaned parameter value; null and empty values are returned unchanged. */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (StringUtils.isEmpty(value)) {
            return value;
        }
        return cleanSqlKeyWords(cleanValue(name, value));
    }

    /** Returns a cleaned copy of the parameter values (the container's array is not modified). */
    @Override
    public String[] getParameterValues(String name) {
        String[] parameterValues = super.getParameterValues(name);
        if (parameterValues == null) {
            return null;
        }
        String[] cleaned = new String[parameterValues.length];
        for (int i = 0; i < parameterValues.length; i++) {
            cleaned[i] = cleanSqlKeyWords(cleanValue(name, parameterValues[i]));
        }
        return cleaned;
    }

    /** Returns an unmodifiable, order-preserving map whose values are cleaned like getParameterValues. */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> raw = super.getParameterMap();
        if (raw == null) {
            return null;
        }
        Map<String, String[]> cleaned = new LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : raw.entrySet()) {
            String name = entry.getKey();
            String[] values = entry.getValue();
            if (values == null) {
                cleaned.put(name, null);
                continue;
            }
            String[] cleanedValues = new String[values.length];
            for (int i = 0; i < values.length; i++) {
                cleanedValues[i] = cleanSqlKeyWords(cleanValue(name, values[i]));
            }
            cleaned.put(name, cleanedValues);
        }
        return Collections.unmodifiableMap(cleaned);
    }

    /** Returns the cleaned header value, or null when the header is absent. */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        if (value == null) {
            // Preserve "absent" semantics: callers such as CORS checks test for null.
            return null;
        }
        return cleanSqlKeyWords(cleanXSS(value));
    }

    /** Returns the cleaned query string, or null when the request has none. */
    @Override
    public String getQueryString() {
        String queryString = super.getQueryString();
        return queryString == null ? null : cleanXSS(queryString);
    }

    /** Returns a stream over the cleaned (re-serialized) JSON body. It may be called repeatedly. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return toServletInputStream(getCleanedBody());
    }

    /** Returns a reader over the cleaned body, so getReader() cannot bypass the cleaning. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), UTF_8));
    }

    /** Reads and cleans the underlying body once; later calls reuse the cached result. */
    private String getCleanedBody() throws IOException {
        if (cleanedBody == null) {
            cleanedBody = getRequestBody(super.getInputStream());
        }
        return cleanedBody;
    }

    /** Wraps the UTF-8 bytes of {@code body} in a blocking, already-ready ServletInputStream. */
    private static ServletInputStream toServletInputStream(String body) {
        final ByteArrayInputStream source = new ByteArrayInputStream(body.getBytes(UTF_8));
        return new ServletInputStream() {
            @Override
            public int read() {
                // Returns -1 at end of data (an empty body is immediately at end).
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is in memory, so reads never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported by this wrapper.
            }
        };
    }

    /**
     * Reads the whole body as UTF-8 and returns it cleaned.
     *
     * <p>Line breaks are dropped while reading. A blank body yields "". A body that is not valid
     * JSON also yields "" (see transJsonNode).
     */
    private String getRequestBody(InputStream stream) {
        StringBuilder body = new StringBuilder();
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, UTF_8));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                body.append(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        String raw = body.toString();
        if (raw.trim().isEmpty()) {
            return "";
        }
        return transJsonNode(raw);
    }

    /** Cleans a single value, keeping markup when {@code name} is a rich-text field. */
    private String cleanValue(String name, String value) {
        return richTextKeySet.contains(name) ? cleanRichTextXSS(value) : cleanXSS(value);
    }

    /** Removes script blocks, eval(...) and "alert", and escapes angle brackets. Blank input yields "". */
    private String cleanXSS(String valueP) {
        if (StringUtils.isBlank(valueP)) {
            return "";
        }
        String value = valueP.replaceAll(SCRIPT_TAG_REGEX, "");
        value = value.replaceAll(JAVASCRIPT_TAG_REGEX, "");
        value = value.replaceAll("<", "&lt;").replaceAll(">", "&gt;");
        value = value.replaceAll(EVAL_CALL_REGEX, "");
        value = value.replaceAll("alert", "");
        value = cleanSqlKeyWords(value);
        return value;
    }

    /** Like cleanXSS but keeps HTML markup (no angle-bracket escaping). Blank input yields "". */
    private String cleanRichTextXSS(String valueP) {
        if (StringUtils.isBlank(valueP)) {
            return "";
        }
        String value = valueP.replaceAll(EVAL_CALL_REGEX, "");
        value = value.replaceAll(SCRIPT_TAG_REGEX, "");
        value = value.replaceAll(JAVASCRIPT_TAG_REGEX, "");
        value = value.replaceAll("alert", "");
        value = cleanSqlKeyWords(value);
        return value;
    }

    /** Intentionally a no-op: SQL keyword stripping is disabled (rely on parameterized queries). */
    private String cleanSqlKeyWords(String value) {
        return value;
    }

    /** Parses {@code jsonStr}, cleans every string value and re-serializes it; returns "" on failure. */
    private String transJsonNode(String jsonStr) {
        try {
            JsonNode root = OBJECT_MAPPER.readTree(jsonStr);
            if (root == null || root.isMissingNode()) {
                return "";
            }
            return OBJECT_MAPPER.writeValueAsString(cleanJsonNodeXSS(root, false));
        } catch (JsonProcessingException e) {
            // NOTE(review): non-JSON bodies are dropped here, as in the original. Confirm this is desired.
            e.printStackTrace();
            return "";
        }
    }

    /**
     * Recursively cleans a JSON tree.
     *
     * @param node     the node to clean
     * @param richText whether string values at this position belong to a rich-text field
     * @return a cleaned String for text nodes, a Map for objects, a List for arrays, and the node
     *     itself for numbers, booleans and null (so they keep their JSON type)
     */
    private Object cleanJsonNodeXSS(JsonNode node, boolean richText) {
        if (node.isTextual()) {
            String value = node.asText();
            return richText ? cleanRichTextXSS(value) : cleanXSS(value);
        }
        if (node.isObject()) {
            Map<String, Object> map = new LinkedHashMap<>();
            Iterator<Map.Entry<String, JsonNode>> fields = node.fields();
            while (fields.hasNext()) {
                Map.Entry<String, JsonNode> field = fields.next();
                String key = field.getKey();
                map.put(key, cleanJsonNodeXSS(field.getValue(), richTextKeySet.contains(key)));
            }
            return map;
        }
        if (node.isArray()) {
            // Elements inherit the rich-text flag of the field that holds the array.
            List<Object> elementList = new ArrayList<>(node.size());
            for (JsonNode element : node) {
                elementList.add(cleanJsonNodeXSS(element, richText));
            }
            return elementList;
        }
        return node;
    }
}
3.针对非富文本接口进行过滤
package com.doctortech.tmc.support.xss;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.util.*;
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
private static String key = "and|exec|insert|select|delete|update|count|*|%|chr|mid|master|truncate|char|declare|;|or|-|+";
private static Set<String> notAllowedKeyWords = new HashSet<String>(0);
static {
String[] keyStr = key.split("\\|");
for (String str : keyStr) {
notAllowedKeyWords.add(str);
}
}
public XssHttpServletRequestWrapper(HttpServletRequest servletRequest) throws IOException {
super(servletRequest);
}
@Override
public ServletInputStream getInputStream() throws IOException {
String bodyStr = getRequestBody(super.getInputStream());
if ("".equals(bodyStr)) {
return new ServletInputStream() {
@Override
public int read() throws IOException {
return 0;
}
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setReadListener(ReadListener readListener) {
}
};
}
final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bodyStr.getBytes());
return new ServletInputStream() {
@Override
public int read() throws IOException {
return byteArrayInputStream.read();
}
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setReadListener(ReadListener readListener) {
}
};
}
private String getRequestBody(InputStream stream) {
String line = "";
StringBuilder body = new StringBuilder();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream, Charset.forName("UTF-8")));
try {
while ((line = reader.readLine()) != null) {
body.append(line);
}
} catch (IOException e) {
e.printStackTrace();
}
if (body == null) {
return "";
}
String data = transJsonNode(body.toString());
return data;
}
private static String xssEncode(String s, int type) {
if (s == null || s.isEmpty()) {
return s;
}
StringBuilder sb = new StringBuilder(s.length() + 16);
for (int i = 0; i < s.length(); i++) {
char c = s.charAt(i);
if (type == 0) {
switch (c) {
case '\'':
sb.append('‘');
break;
case '\"':
sb.append('“');
break;
case '>':
sb.append('>');
break;
case '<':
sb.append('<');
break;
case '&':
sb.append('&');
break;
case '\\':
sb.append('\');
break;
case '#':
sb.append('#');
break;
case '%':
processUrlEncoder(sb, s, i);
break;
default:
sb.append(c);
break;
}
} else {
switch (c) {
case '>':
sb.append('>');
break;
case '<':
sb.append('<');
break;
case '&':
sb.append('&');
break;
case '#':
sb.append('#');
break;
case '%':
processUrlEncoder(sb, s, i);
break;
default:
sb.append(c);
break;
}
}
}
return sb.toString();
}
public static void processUrlEncoder(StringBuilder sb, String s, int index) {
if (s.length() >= index + 2) {
if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'c' || s.charAt(index + 2) == 'C')) {
sb.append('<');
return;
}
if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '0') {
sb.append('<');
return;
}
if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'e' || s.charAt(index + 2) == 'E')) {
sb.append('>');
return;
}
if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '2') {
sb.append('>');
return;
}
}
sb.append(s.charAt(index));
}
@Override
public String getParameter(String parameter) {
String value = super.getParameter(parameter);
if (value == null) {
return null;
}
return cleanXSS(value);
}
@Override
public String[] getParameterValues(String parameter) {
String[] values = super.getParameterValues(parameter);
if (values == null) {
return null;
}
int count = values.length;
String[] encodedValues = new String[count];
for (int i = 0; i < count; i++) {
encodedValues[i] = cleanXSS(values[i]);
}
return encodedValues;
}
@Override
public Map<String, String[]> getParameterMap(){
Map<String, String[]> values = super.getParameterMap();
if (values == null) {
return null;
}
Map<String, String[]> result = new HashMap<>();
for(String key:values.keySet()){
String encodedKey = cleanXSS(key);
int count = values.get(key).length;
String[] encodedValues = new String[count];
for (int i = 0; i < count; i++){
encodedValues[i] = cleanXSS(values.get(key)[i]);
}
result.put(encodedKey,encodedValues);
}
return result;
}
@Override
public String getHeader(String name) {
String value = super.getHeader(name);
if (value == null) {
return null;
}
return cleanXSS(value);
}
private String cleanXSS(String valueP) {
String value = valueP.replaceAll("<[\\s]*?script[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?script[\\s]*?>", "");
value = value.replaceAll("<[\\s]*?javascript[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?javascript[\\s]*?>", "");
value = value.replaceAll("<", "<").replaceAll(">", ">");
value = value.replaceAll("<", "& lt;").replaceAll(">", "& gt;");
value = value.replaceAll("eval\\((.*)\\)", "");
value = value.replaceAll("alert", "");
value = cleanSqlKeyWords(value);
return value;
}
private String cleanSqlKeyWords(String value) {
String paramValue = value;
return paramValue;
}
private String transJsonNode(String jsonStr) {
String str = "";
try {
ObjectMapper objectMapper = new ObjectMapper();
JsonNode jsonNode = objectMapper.readTree(jsonStr);
str = objectMapper.writeValueAsString(cleanJsonNodeXSS(jsonNode));
} catch (JsonProcessingException e) {
e.printStackTrace();
}
return str;
}
private Object cleanJsonNodeXSS(JsonNode jsonNode) {
Iterator<Map.Entry<String, JsonNode>> fields = jsonNode.fields();
if (!fields.hasNext()) {
String value = jsonNode.asText();
return cleanXSS(value);
}
Map<String, Object> map = new HashMap<>();
while(fields.hasNext()) {
Map.Entry<String, JsonNode> next = fields.next();
if (next.getValue().isTextual()) {
String value = next.getValue().asText();
String str = cleanXSS(value);
map.put(next.getKey(),str);
}else if (next.getValue().isObject()){
map.put(next.getKey(),cleanJsonNodeXSS(next.getValue()));
}else if(next.getValue().isArray()) {
List<Object> elementList = new ArrayList<>();
Iterator<JsonNode> elements = next.getValue().elements();
while (elements.hasNext()) {
JsonNode childrenNext = elements.next();
Object nodeMap = cleanJsonNodeXSS(childrenNext);
elementList.add(nodeMap);
}
map.put(next.getKey(),elementList);
}
else {
map.put(next.getKey(),next.getValue());
}
}
return map;
}
}