配置类
import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.List;
/**
 * Externalized settings for the XSS filter, bound from properties under the {@code pbs.xss} prefix.
 */
@ConfigurationProperties(prefix = "pbs.xss")
public class XssProperties {

    /** Whether XSS filtering is switched on; {@code true} unless configured otherwise. */
    private boolean enabled = true;

    /** URL patterns excluded from XSS filtering. */
    private List<String> excludeUrls;

    /**
     * URL patterns that should be XSS-filtered.
     *
     * <p>NOTE(review): the name looks like a typo of "includeUrls"; it is kept because renaming
     * it would change the bound key {@code pbs.xss.cludeUrls}. No code visible here reads this
     * list — confirm whether it is used.
     */
    private List<String> cludeUrls;

    public boolean isEnabled() {
        return this.enabled;
    }

    public List<String> getExcludeUrls() {
        return this.excludeUrls;
    }

    public List<String> getCludeUrls() {
        return this.cludeUrls;
    }

    public void setEnabled(boolean enabled) {
        this.enabled = enabled;
    }

    public void setExcludeUrls(List<String> excludeUrls) {
        this.excludeUrls = excludeUrls;
    }

    public void setCludeUrls(List<String> cludeUrls) {
        this.cludeUrls = cludeUrls;
    }
}
yml配置:
pbs:
  xss:
    excludeUrls: /a/b,/c/d
    cludeUrls: /xssController/**
    enabled: true
自定义request:
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.springframework.util.StreamUtils;
import org.springframework.web.util.ContentCachingRequestWrapper;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
/**
 * Request wrapper that HTML-escapes parameter and header values and serves a JSON-escaped copy
 * of the request body.
 *
 * <p>For ordinary (e.g. JSON) bodies, the body is read once in the constructor, passed through
 * {@code XssFilterUtil.cleanJson}, stored as the request attribute
 * {@code "pbs.common.exception.body"}, and served from memory by {@link #getInputStream()} and
 * {@link #getReader()}.
 *
 * <p>Form-urlencoded and multipart bodies are NOT consumed here. The servlet container parses
 * them from the original stream. Reading that stream first would make {@code getParameter()} /
 * {@code getParts()} lose the posted data (Servlet spec, "When Parameters Are Available"). Their
 * parameter values are still escaped by {@link #getParameter(String)} and
 * {@link #getParameterValues(String)}.
 */
public class XssHttpServletRequestWrapper extends ContentCachingRequestWrapper {
    private static final String FORM_CONTENT_TYPE = "application/x-www-form-urlencoded";
    private static final String MULTIPART_CONTENT_TYPE = "multipart/";

    // Original (unwrapped) request.
    private final HttpServletRequest orgRequest;
    // Charset used to decode the raw body and re-encode the cleaned body (and by getReader()).
    private final Charset charset;
    // Cleaned request body; empty when the body is left to the container (form/multipart).
    private final byte[] content;
    // True when content holds the buffered body; false when body reads are delegated.
    private final boolean bodyBuffered;
    // Lazily created stream over content; the same instance is returned on every call.
    private XssInputStream inputStream;

    /**
     * Wraps {@code request}, buffering and cleaning its body unless it is form or multipart data.
     *
     * @param request the request to wrap
     * @throws IOException if reading the request body fails
     */
    public XssHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.orgRequest = request;
        this.charset = resolveCharset(request);
        this.bodyBuffered = !isContainerParsedBody(request);
        if (bodyBuffered) {
            byte[] bytes = StreamUtils.copyToByteArray(request.getInputStream());
            String body = XssFilterUtil.cleanJson(new String(bytes, charset));
            this.content = body.getBytes(charset);
            // Store the cleaned request body as a request attribute.
            request.setAttribute("pbs.common.exception.body", body);
        } else {
            this.content = new byte[0];
        }
    }

    /**
     * Returns the HTML-escaped value of a request parameter.
     *
     * <p>The lookup uses {@code name} as given. Escaping the name would make parameters whose
     * names contain escapable characters unreachable.
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        return org.springframework.util.StringUtils.hasText(value) ? XssFilterUtil.clean(value) : value;
    }

    /**
     * Returns HTML-escaped copies of all values of a request parameter.
     *
     * <p>A new array is returned so the delegate's (possibly container-owned) array is never
     * modified in place.
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(name);
        if (values == null) {
            return null;
        }
        String[] cleaned = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            cleaned[i] = XssFilterUtil.clean(values[i]);
        }
        return cleaned;
    }

    /**
     * Returns the cleaned request body bytes. The array is empty for form and multipart requests,
     * whose bodies are not buffered.
     */
    public byte[] getContent() {
        return this.content;
    }

    /** Returns the HTML-escaped value of a header, looked up by its unmodified name. */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        return org.springframework.util.StringUtils.hasText(value) ? XssFilterUtil.clean(value) : value;
    }

    /**
     * Returns a stream over the cleaned body, or the delegate's stream when the body was left to
     * the container.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (!bodyBuffered) {
            return super.getInputStream();
        }
        if (inputStream == null) {
            inputStream = new XssInputStream(content);
        }
        return inputStream;
    }

    /** Returns a reader over {@link #getInputStream()} decoded with the request's charset. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }

    // Uses the request's declared encoding. Falls back to UTF-8 (the JSON default) when the
    // encoding is absent or names an unknown/malformed charset; the name is client-controlled.
    private static Charset resolveCharset(HttpServletRequest request) {
        String encoding = request.getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // IllegalCharsetNameException / UnsupportedCharsetException: use the fallback.
            }
        }
        return StandardCharsets.UTF_8;
    }

    // True for content types whose body the container parses itself from the original stream.
    private static boolean isContainerParsedBody(HttpServletRequest request) {
        String contentType = request.getContentType();
        if (contentType == null) {
            return false;
        }
        return contentType.regionMatches(true, 0, FORM_CONTENT_TYPE, 0, FORM_CONTENT_TYPE.length())
                || contentType.regionMatches(true, 0, MULTIPART_CONTENT_TYPE, 0, MULTIPART_CONTENT_TYPE.length());
    }
}
/**
 * In-memory {@link ServletInputStream} that serves an already-cleaned request body.
 */
class XssInputStream extends ServletInputStream {
    private final ByteArrayInputStream bis;

    /**
     * @param content the bytes to serve; not copied
     */
    public XssInputStream(byte[] content) {
        this.bis = new ByteArrayInputStream(content);
    }

    /** Returns true once every byte has been read. */
    @Override
    public boolean isFinished() {
        return this.bis.available() == 0;
    }

    /**
     * Always true. The whole body is already in memory, so a read never blocks. The previous
     * {@code false} told non-blocking readers that data would never become available.
     */
    @Override
    public boolean isReady() {
        return true;
    }

    /** Non-blocking read listeners are not supported; the listener is ignored. */
    @Override
    public void setReadListener(ReadListener readListener) {
    }

    @Override
    public int read() throws IOException {
        return this.bis.read();
    }

    // Bulk read delegates directly instead of falling back to byte-at-a-time InputStream.read.
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        return this.bis.read(b, off, len);
    }

    @Override
    public int available() throws IOException {
        return this.bis.available();
    }
}
/**
 * Static helpers that neutralize HTML metacharacters in request data.
 */
class XssFilterUtil {
    // The only characters cleanJson escapes; quotes are left alone so JSON bodies stay parseable.
    private static final char[] ESCAPE_JSON_CHARS = new char[]{'&', '<', '>'};

    private XssFilterUtil() {
    }

    /**
     * HTML-escapes a single value (e.g. a parameter or header value).
     *
     * <p>Uses the commons-text {@code StringEscapeUtils}; the commons-lang3 class of the same
     * name is deprecated.
     *
     * @param content value to escape; {@code null} yields {@code null}
     * @return the HTML4-escaped value
     */
    public static String clean(String content) {
        return StringEscapeUtils.escapeHtml4(content);
    }

    /**
     * Escapes a JSON body for XSS. Unlike {@link #clean(String)}, it does not touch double quotes,
     * which would break the JSON.
     *
     * <p>Only {@code &}, {@code <} and {@code >} are replaced, by {@code &amp;}, {@code &lt;} and
     * {@code &gt;}. All other characters, including non-ASCII text, are copied unchanged.
     * Previously the replacements were no-ops, and non-ASCII characters were rewritten as numeric
     * entities, which corrupted JSON text and split surrogate pairs.
     *
     * @param json the raw body; {@code null} is returned unchanged
     * @return the escaped body, or {@code json} itself when it contains none of the three characters
     */
    public static String cleanJson(String json) {
        if (StringUtils.containsNone(json, ESCAPE_JSON_CHARS)) {
            return json;
        }
        // Leave ~10% headroom for the longer entity replacements.
        StringBuilder out = new StringBuilder(json.length() + json.length() / 10);
        for (int i = 0; i < json.length(); i++) {
            char c = json.charAt(i);
            switch (c) {
                case '&':
                    out.append("&amp;");
                    break;
                case '<':
                    out.append("&lt;");
                    break;
                case '>':
                    out.append("&gt;");
                    break;
                default:
                    out.append(c);
            }
        }
        return out.toString();
    }
}
配置过滤器:
import org.apache.commons.lang3.SerializationUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Servlet filter that wraps each request in {@link XssHttpServletRequestWrapper}, except for
 * requests whose servlet path matches one of {@code pbs.xss.excludeUrls}.
 *
 * <p>Active unless {@code pbs.xss.enabled} is set to something other than {@code true}.
 */
@Configuration
@EnableConfigurationProperties(XssProperties.class)
@Order(999)
@ConditionalOnProperty(prefix = "pbs.xss", name = "enabled", havingValue = "true", matchIfMissing = true)
public class XssFilter extends OncePerRequestFilter {
    @Autowired
    private XssProperties properties;

    /**
     * Passes excluded requests through untouched and all others through the XSS wrapper.
     *
     * <p>The early {@code return} is essential. Without it, excluded requests ran the rest of
     * the chain twice.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        if (isExcluded(request)) {
            filterChain.doFilter(request, response);
            return;
        }
        filterChain.doFilter(new XssHttpServletRequestWrapper(request), response);
    }

    /**
     * Returns true if the servlet path matches any configured exclude entry.
     *
     * <p>Each entry is treated as a regular expression anchored at the start of the path
     * ({@code "^" + entry}, searched with {@code find()}).
     */
    private boolean isExcluded(HttpServletRequest request) {
        List<String> excludeUrls = properties.getExcludeUrls();
        if (excludeUrls == null || excludeUrls.isEmpty()) {
            return false;
        }
        String url = request.getServletPath();
        for (String entry : excludeUrls) {
            if (Pattern.compile("^" + entry).matcher(url).find()) {
                return true;
            }
        }
        return false;
    }
}