背景:公司项目为微服务项目,使用了SpringCloudGateway,目前有需要防护xss攻击请求的需求
实现方案:继承AbstractGatewayFilterFactory,通过yml文件自定义配置某些需要xss防护的服务。xss匹配:自定义正则表达式匹配
之前写的一版,自己创建新的DataBuffer来读取request body里的内容,上生产后出现了堆外内存泄露。
改版后的代码如下,还请各位大佬指正
gateway yml配置项:
XssRequestFirewallGatewayFilterFactory(参考:org.springframework.cloud.gateway.filter.factory.rewrite.ModifyRequestBodyGatewayFilterFactory)
import com.alibaba.fastjson.JSONObject;
import com.mg.mg.gateway.utils.CommonUtil;
import com.mg.mg.gateway.utils.XssCleanRuleUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.HasRouteId;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.io.buffer.*;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
@Component
@Slf4j
public class XssRequestFirewallGatewayFilterFactory
extends AbstractGatewayFilterFactory<XssRequestFirewallGatewayFilterFactory.Config> {
private final List<HttpMessageReader<?>> messageReaders;
public XssRequestFirewallGatewayFilterFactory() {
super(XssRequestFirewallGatewayFilterFactory.Config.class);
this.messageReaders = HandlerStrategies.withDefaults().messageReaders();
}
public XssRequestFirewallGatewayFilterFactory(List<HttpMessageReader<?>> messageReaders) {
super(XssRequestFirewallGatewayFilterFactory.Config.class);
this.messageReaders = messageReaders;
}
private static final String CONTENT_TYPE = "Content-Type";
private static final String CONTENT_TYPE_JSON = "application/json";
@Override
public GatewayFilter apply(Config config) {
return (exchange, chain) -> {
ServerHttpRequest request = exchange.getRequest();
URI uri = request.getURI();
MultiValueMap<String, String> queryParams = request.getQueryParams();
switch (Objects.requireNonNull(request.getMethod())) {
case GET:
if (checkParamsXss(exchange, request, uri, queryParams)) {
return exchange.getResponse().setComplete();
}
URI newUri1 = UriComponentsBuilder.fromUri(uri).build(true).toUri();
ServerHttpRequest newRequest1 = exchange.getRequest().mutate().uri(newUri1).build();
return chain.filter(exchange.mutate().request(newRequest1).build());
case POST:
String contentType = request.getHeaders().getFirst(CONTENT_TYPE);
//只处理contentType为application/json的post请求
if (StringUtils.isNotBlank(contentType) && contentType.contains(CONTENT_TYPE_JSON)) {
if (checkParamsXss(exchange, request, uri, queryParams)) {
return exchange.getResponse().setComplete();
}
//Class inClass = config.getInClass();
ServerRequest serverRequest = ServerRequest.create(exchange, messageReaders);
AtomicReference<String> bodyAc = new AtomicReference<>();
// TODO: flux or mono
Mono<String> modifiedBody = serverRequest.bodyToMono(String.class)
.flatMap(originalBody -> {
bodyAc.set(originalBody);
return Mono.just(originalBody);
});
BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, String.class);
HttpHeaders headers = new HttpHeaders();
headers.putAll(exchange.getRequest().getHeaders());
// the new content type will be computed by bodyInserter
// and then set in the request decorator
headers.remove(HttpHeaders.CONTENT_LENGTH);
CachedBodyOutputMessageInner outputMessage = new CachedBodyOutputMessageInner(exchange, headers);
return bodyInserter.insert(outputMessage, new BodyInserterContext())
// .log("modify_request", Level.INFO)
.then(Mono.defer(() -> {
String bodyString = bodyAc.get();
if (XssCleanRuleUtils.xssMatch(bodyString)) {
log.info("检测到xss注入攻击,uri:{},bodyString:{},ip:{}", uri, bodyString, CommonUtil.getIpAddr(request));
ServerWebExchangeUtils.setResponseStatus(exchange, HttpStatus.FORBIDDEN);
return exchange.getResponse().setComplete();
}
ServerHttpRequest decorator = decorate(exchange, headers, outputMessage);
return chain.filter(exchange.mutate().request(decorator).build());
})).onErrorResume((Function<Throwable, Mono<Void>>) throwable -> release(exchange,
outputMessage, throwable));
}
break;
default:
break;
}
return chain.filter(exchange);
};
}
private boolean checkParamsXss(ServerWebExchange exchange, ServerHttpRequest request, URI uri, MultiValueMap<String, String> queryParams) {
for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
List<String> value = entry.getValue();
if (XssCleanRuleUtils.xssMatch(Arrays.toString(value.toArray()))) {
log.info("检测到xss注入攻击,uri:{},queryParams:{},ip:{}", uri, JSONObject.toJSONString(queryParams), CommonUtil.getIpAddr(request));
ServerWebExchangeUtils.setResponseStatus(exchange, HttpStatus.FORBIDDEN);
return true;
}
}
return false;
}
protected Mono<Void> release(ServerWebExchange exchange, CachedBodyOutputMessageInner outputMessage,
Throwable throwable) {
if (outputMessage.isCached()) {
return outputMessage.getBody().map(DataBufferUtils::release).then(Mono.error(throwable));
}
return Mono.error(throwable);
}
ServerHttpRequestDecorator decorate(ServerWebExchange exchange, HttpHeaders headers,
CachedBodyOutputMessageInner outputMessage) {
return new ServerHttpRequestDecorator(exchange.getRequest()) {
@Override
public HttpHeaders getHeaders() {
long contentLength = headers.getContentLength();
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.putAll(headers);
if (contentLength > 0) {
httpHeaders.setContentLength(contentLength);
} else {
// TODO: this causes a 'HTTP/1.1 411 Length Required' // on
// httpbin.org
httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
}
return httpHeaders;
}
@Override
public Flux<DataBuffer> getBody() {
return outputMessage.getBody();
}
};
}
private <T> T getOrDefault(T configValue, T defaultValue) {
return (configValue != null) ? configValue : defaultValue;
}
public static class CachedBodyOutputMessageInner implements ReactiveHttpOutputMessage {
private final DataBufferFactory bufferFactory;
private final HttpHeaders httpHeaders;
private boolean cached = false;
private Flux<DataBuffer> body = Flux
.error(new IllegalStateException("The body is not set. " + "Did handling complete with success?"));
public CachedBodyOutputMessageInner(ServerWebExchange exchange, HttpHeaders httpHeaders) {
this.bufferFactory = exchange.getResponse().bufferFactory();
this.httpHeaders = httpHeaders;
}
@Override
public void beforeCommit(Supplier<? extends Mono<Void>> action) {
}
@Override
public boolean isCommitted() {
return false;
}
boolean isCached() {
return this.cached;
}
@Override
public HttpHeaders getHeaders() {
return this.httpHeaders;
}
@Override
public DataBufferFactory bufferFactory() {
return this.bufferFactory;
}
/**
* Return the request body, or an error stream if the body was never set or when.
*
* @return body as {@link Flux}
*/
public Flux<DataBuffer> getBody() {
return this.body;
}
@Override
public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
this.body = Flux.from(body);
this.cached = true;
return Mono.empty();
}
@Override
public Mono<Void> writeAndFlushWith(Publisher<? extends Publisher<? extends DataBuffer>> body) {
return writeWith(Flux.from(body).flatMap(p -> p));
}
@Override
public Mono<Void> setComplete() {
return writeWith(Flux.empty());
}
}
public static class Config implements HasRouteId {
private String routeId;
@Override
public void setRouteId(String routeId) {
this.routeId = routeId;
}
@Override
public String getRouteId() {
return this.routeId;
}
}
}
XssCleanRuleUtils
import org.springframework.util.StringUtils;
import java.util.StringJoiner;
import java.util.regex.Pattern;
/**
 * Regex-based detector and sanitizer for XSS / SQL-injection-looking input.
 *
 * <p>All rules are OR-ed into a single case-insensitive pattern that is compiled once at class-load
 * time.
 *
 * <p>NOTE(review): the last rule also matches the plain words "and", "or", "update", "count", etc.
 * It also matches the characters {@code * ; + ' %}. Ordinary JSON payloads can therefore be flagged;
 * tune the rules per service if false positives show up.
 */
public final class XssCleanRuleUtils {

    private static final String[] XSS_SCRIPT_REG_ARR = {
            "<script>(.*?)</script>",
            "src[\r\n]*=[\r\n]*\\'(.*?)\\'",
            "</script>",
            "<script(.*?)>",
            "eval\\((.*?)\\)",
            "expression\\((.*?)\\)",
            "javascript:",
            "vbscript:",
            "onload(.*?)=",
            "\\b(and|exec|insert|select|drop|grant|alter|delete|update|count|chr|mid|master|truncate|char|declare|or)\\b|(\\*|;|\\+|'|%)"
    };

    /** Union of all rules; immutable and thread-safe, so it is shared by every caller. */
    private static final Pattern XSS_PATTERN = Pattern.compile(String.join("|", XSS_SCRIPT_REG_ARR),
            Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL);

    private XssCleanRuleUtils() {
    }

    public static void main(String[] args) {
        System.out.println(xssMatch("select:567"));
    }

    /**
     * Tests whether {@code value} contains anything matching the XSS / injection rules.
     *
     * <p>All whitespace is removed before matching, so split-up payloads such as
     * {@code "java script:"} are still detected.
     *
     * @param value text to check; may be {@code null}
     * @return {@code true} if any rule matches; {@code false} for {@code null} or whitespace-only input
     */
    public static boolean xssMatch(String value) {
        String compact = removeAllWhitespace(value);
        if (compact == null || compact.isEmpty()) {
            return false;
        }
        return XSS_PATTERN.matcher(compact).find();
    }

    /**
     * Replaces every rule match with a marker and then HTML-escapes angle brackets.
     *
     * <p>Only {@code <} and {@code >} are escaped. The marker inserted for each match is therefore
     * escaped as well.
     *
     * @param value text to sanitize; may be {@code null}
     * @return the sanitized text, or {@code value} itself if it is {@code null} or empty
     */
    public static String xssClean(String value) {
        if (value == null || value.isEmpty()) {
            return value;
        }
        String cleaned = XSS_PATTERN.matcher(value).replaceAll("<XssScript>***</XssScript>");
        // Previously replace("<", "<") was a no-op; escape to HTML entities so nothing is rendered as markup.
        return cleaned.replace("<", "&lt;").replace(">", "&gt;");
    }

    /**
     * Removes every character for which {@link Character#isWhitespace(char)} is true.
     *
     * <p>Same semantics as Spring's {@code StringUtils.trimAllWhitespace}.
     *
     * @return the compacted string, or {@code value} itself when it is {@code null} or empty
     */
    private static String removeAllWhitespace(String value) {
        if (value == null || value.isEmpty()) {
            return value;
        }
        StringBuilder sb = new StringBuilder(value.length());
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            if (!Character.isWhitespace(c)) {
                sb.append(c);
            }
        }
        return sb.toString();
    }
}
CommonUtil
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.*;
/** Miscellaneous gateway helpers. */
public class CommonUtil {

    private static final String UNKNOWN = "unknown";
    private static final String LOCALHOST_V4 = "127.0.0.1";
    private static final String LOCALHOST_V6 = "0:0:0:0:0:0:0:1";
    /** Client-IP headers consulted in priority order. */
    private static final String[] IP_HEADERS = {"x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP"};

    /**
     * Generates a string of {@code length} random decimal digits.
     *
     * <p>Not suitable for security tokens: {@link Random} is not cryptographically secure.
     *
     * @param length number of digits; non-positive values yield an empty string
     * @return the random digit string
     */
    public static String getRandomString(int length) {
        String base = "0123456789";
        Random random = new Random();
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < length; i++) {
            sb.append(base.charAt(random.nextInt(base.length())));
        }
        return sb.toString();
    }

    /**
     * Determines the client IP address.
     *
     * <p>The proxy headers are tried first (x-forwarded-for, Proxy-Client-IP, WL-Proxy-Client-IP). If
     * none is usable, the connection's remote address is used instead. A loopback remote address is
     * replaced by this host's own address.
     *
     * <p>When the header carries a comma-separated proxy chain, the first entry (the original client)
     * is returned.
     *
     * <p>NOTE(review): these headers are client-controlled and can be spoofed unless a trusted proxy
     * overwrites them. Use the result for logging, not for access control.
     *
     * @param request current request
     * @return the client IP, or {@code null} if it cannot be determined
     */
    public static String getIpAddr(ServerHttpRequest request) {
        String ipAddress = firstUsableHeader(request.getHeaders());
        if (ipAddress == null) {
            ipAddress = remoteIp(request);
        }
        if (ipAddress != null) {
            int comma = ipAddress.indexOf(',');
            if (comma > 0) {
                ipAddress = ipAddress.substring(0, comma).trim();
            }
        }
        return ipAddress;
    }

    /** Returns the first IP header value that is non-blank and not "unknown", or {@code null}. */
    private static String firstUsableHeader(HttpHeaders headers) {
        for (String name : IP_HEADERS) {
            // getFirst returns the raw value; String.valueOf(headers.get(..)) produced "null" or "[..]".
            String value = headers.getFirst(name);
            if (value != null) {
                String trimmed = value.trim();
                if (!trimmed.isEmpty() && !UNKNOWN.equalsIgnoreCase(trimmed)) {
                    return trimmed;
                }
            }
        }
        return null;
    }

    /**
     * Returns the literal IP of the connection peer without any DNS lookup, or {@code null}.
     *
     * <p>A loopback peer is replaced with this host's address when it can be resolved.
     */
    private static String remoteIp(ServerHttpRequest request) {
        if (request.getRemoteAddress() == null) {
            return null;
        }
        InetAddress address = request.getRemoteAddress().getAddress();
        if (address == null) {
            // Unresolved address: fall back to the literal host string (no reverse lookup).
            return request.getRemoteAddress().getHostString();
        }
        String ip = address.getHostAddress();
        if (LOCALHOST_V4.equals(ip) || LOCALHOST_V6.equals(ip)) {
            // Use the IP configured on this machine's network interface instead of loopback.
            try {
                return InetAddress.getLocalHost().getHostAddress();
            } catch (UnknownHostException ignored) {
                // Best effort: keep the loopback address rather than failing the request.
                return ip;
            }
        }
        return ip;
    }
}