Implementing XSS Protection with a Spring Cloud Gateway Filter

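Background: our company's project is a microservice system that uses Spring Cloud Gateway, and we now need to protect it against requests carrying XSS attacks.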

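Approach: extend AbstractGatewayFilterFactory and use the yml file to configure which services need XSS protection. XSS detection is done by matching against custom regular expressions.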

My first version created its own new DataBuffer to read the request body, and it leaked off-heap memory once it went to production.

The revised code is below; corrections from more experienced readers are welcome.

Gateway yml configuration:

XssRequestFirewallGatewayFilterFactory (modeled on org.springframework.cloud.gateway.filter.factory.rewrite.ModifyRequestBodyGatewayFilterFactory):

import com.alibaba.fastjson.JSONObject;
import com.mg.mg.gateway.utils.CommonUtil;
import com.mg.mg.gateway.utils.XssCleanRuleUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.HasRouteId;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.io.buffer.*;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

@Component
@Slf4j
public class XssRequestFirewallGatewayFilterFactory
        extends AbstractGatewayFilterFactory<XssRequestFirewallGatewayFilterFactory.Config> {

    private final List<HttpMessageReader<?>> messageReaders;

    public XssRequestFirewallGatewayFilterFactory() {
        super(XssRequestFirewallGatewayFilterFactory.Config.class);
        this.messageReaders = HandlerStrategies.withDefaults().messageReaders();
    }

    public XssRequestFirewallGatewayFilterFactory(List<HttpMessageReader<?>> messageReaders) {
        super(XssRequestFirewallGatewayFilterFactory.Config.class);
        this.messageReaders = messageReaders;
    }

    private static final String CONTENT_TYPE = "Content-Type";

    private static final String CONTENT_TYPE_JSON = "application/json";

    @Override
    public GatewayFilter apply(Config config) {
        return (exchange, chain) -> {
            ServerHttpRequest request = exchange.getRequest();
            URI uri = request.getURI();
            MultiValueMap<String, String> queryParams = request.getQueryParams();
            switch (Objects.requireNonNull(request.getMethod())) {
                case GET:
                    if (checkParamsXss(exchange, request, uri, queryParams)) {
                        return exchange.getResponse().setComplete();
                    }
                    URI newUri1 = UriComponentsBuilder.fromUri(uri).build(true).toUri();
                    ServerHttpRequest newRequest1 = exchange.getRequest().mutate().uri(newUri1).build();
                    return chain.filter(exchange.mutate().request(newRequest1).build());
                case POST:
                    String contentType = request.getHeaders().getFirst(CONTENT_TYPE);
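                    // Only inspect POST requests whose Content-Type is application/json
                    // (media types are case-insensitive, so compare ignoring case)
                    if (StringUtils.isNotBlank(contentType) && StringUtils.containsIgnoreCase(contentType, CONTENT_TYPE_JSON)) {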
                        if (checkParamsXss(exchange, request, uri, queryParams)) {
                            return exchange.getResponse().setComplete();
                        }
                        //Class inClass = config.getInClass();
                        ServerRequest serverRequest = ServerRequest.create(exchange, messageReaders);

                        AtomicReference<String> bodyAc = new AtomicReference<>();
                        // TODO: flux or mono
                        Mono<String> modifiedBody = serverRequest.bodyToMono(String.class)
                                .flatMap(originalBody -> {
                                    bodyAc.set(originalBody);
                                    return Mono.just(originalBody);
                                });
                        BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, String.class);
                        HttpHeaders headers = new HttpHeaders();
                        headers.putAll(exchange.getRequest().getHeaders());

                        // the new content type will be computed by bodyInserter
                        // and then set in the request decorator
                        headers.remove(HttpHeaders.CONTENT_LENGTH);

                        CachedBodyOutputMessageInner outputMessage = new CachedBodyOutputMessageInner(exchange, headers);
                        return bodyInserter.insert(outputMessage, new BodyInserterContext())
                                // .log("modify_request", Level.INFO)
                                .then(Mono.defer(() -> {
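                                    String bodyString = bodyAc.get();
                                    if (XssCleanRuleUtils.xssMatch(bodyString)) {
                                        log.info("XSS attack detected, uri:{}, bodyString:{}, ip:{}", uri, bodyString, CommonUtil.getIpAddr(request));
                                        ServerWebExchangeUtils.setResponseStatus(exchange, HttpStatus.FORBIDDEN);
                                        // The cached body is never forwarded on this path, so release its
                                        // buffer here; otherwise each rejected request leaks a buffer
                                        // allocated from the response's (typically pooled, off-heap) factory
                                        return outputMessage.getBody()
                                                .doOnNext(DataBufferUtils::release)
                                                .then(exchange.getResponse().setComplete());
                                    }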
                                    ServerHttpRequest decorator = decorate(exchange, headers, outputMessage);
                                    return chain.filter(exchange.mutate().request(decorator).build());
                                })).onErrorResume((Function<Throwable, Mono<Void>>) throwable -> release(exchange,
                                        outputMessage, throwable));
                    }
                    break;
                default:
                    break;
            }
            return chain.filter(exchange);
        };
    }

    private boolean checkParamsXss(ServerWebExchange exchange, ServerHttpRequest request, URI uri, MultiValueMap<String, String> queryParams) {
        for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
            List<String> value = entry.getValue();
            if (XssCleanRuleUtils.xssMatch(Arrays.toString(value.toArray()))) {
                log.info("XSS attack detected, uri:{}, queryParams:{}, ip:{}", uri, JSONObject.toJSONString(queryParams), CommonUtil.getIpAddr(request));
                ServerWebExchangeUtils.setResponseStatus(exchange, HttpStatus.FORBIDDEN);
                return true;
            }
        }
        return false;
    }

    protected Mono<Void> release(ServerWebExchange exchange, CachedBodyOutputMessageInner outputMessage,
                                 Throwable throwable) {
        if (outputMessage.isCached()) {
            return outputMessage.getBody().map(DataBufferUtils::release).then(Mono.error(throwable));
        }
        return Mono.error(throwable);
    }

    ServerHttpRequestDecorator decorate(ServerWebExchange exchange, HttpHeaders headers,
                                        CachedBodyOutputMessageInner outputMessage) {
        return new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public HttpHeaders getHeaders() {
                long contentLength = headers.getContentLength();
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(headers);
                if (contentLength > 0) {
                    httpHeaders.setContentLength(contentLength);
                } else {
                    // TODO: this causes a 'HTTP/1.1 411 Length Required' // on
                    // httpbin.org
                    httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                }
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                return outputMessage.getBody();
            }
        };
    }

    private <T> T getOrDefault(T configValue, T defaultValue) {
        return (configValue != null) ? configValue : defaultValue;
    }

    public static class CachedBodyOutputMessageInner implements ReactiveHttpOutputMessage {

        private final DataBufferFactory bufferFactory;

        private final HttpHeaders httpHeaders;

        private boolean cached = false;

        private Flux<DataBuffer> body = Flux
                .error(new IllegalStateException("The body is not set. " + "Did handling complete with success?"));

        public CachedBodyOutputMessageInner(ServerWebExchange exchange, HttpHeaders httpHeaders) {
            this.bufferFactory = exchange.getResponse().bufferFactory();
            this.httpHeaders = httpHeaders;
        }

        @Override
        public void beforeCommit(Supplier<? extends Mono<Void>> action) {

        }

        @Override
        public boolean isCommitted() {
            return false;
        }

        boolean isCached() {
            return this.cached;
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.httpHeaders;
        }

        @Override
        public DataBufferFactory bufferFactory() {
            return this.bufferFactory;
        }

        /**
         * Return the request body, or an error stream if the body was never set.
         *
         * @return body as {@link Flux}
         */
        public Flux<DataBuffer> getBody() {
            return this.body;
        }

        @Override
        public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
            this.body = Flux.from(body);
            this.cached = true;
            return Mono.empty();
        }

        @Override
        public Mono<Void> writeAndFlushWith(Publisher<? extends Publisher<? extends DataBuffer>> body) {
            return writeWith(Flux.from(body).flatMap(p -> p));
        }

        @Override
        public Mono<Void> setComplete() {
            return writeWith(Flux.empty());
        }

    }

    public static class Config implements HasRouteId {

        private String routeId;

        @Override
        public void setRouteId(String routeId) {
            this.routeId = routeId;
        }

        @Override
        public String getRouteId() {
            return this.routeId;
        }

    }
}
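The filter above only inspects the bodies of JSON POST requests. Media types are case-insensitive and may carry parameters (`application/json;charset=UTF-8`) or a structured suffix (`application/problem+json`, which a substring test for `application/json` does not catch), and PUT and PATCH requests also carry bodies. The dependency-free sketch below shows one way to decide whether a body should be checked; `JsonBodyCheck`, `isJson` and `shouldInspectBody` are hypothetical names, and treating `+json` types and PUT/PATCH as in scope is an assumption about what your services accept.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/** Hypothetical helper: decides whether a request body should go through the XSS check. */
final class JsonBodyCheck {

    // Methods that commonly carry a request body (assumption: your services accept JSON on all three)
    private static final List<String> BODY_METHODS = Arrays.asList("POST", "PUT", "PATCH");

    private JsonBodyCheck() {
    }

    /** Returns true if the Content-Type value denotes JSON, ignoring case and parameters. */
    static boolean isJson(String contentType) {
        if (contentType == null || contentType.trim().isEmpty()) {
            return false;
        }
        // Drop parameters such as ";charset=UTF-8" and normalise case
        String mediaType = contentType.split(";", 2)[0].trim().toLowerCase(Locale.ROOT);
        return mediaType.equals("application/json") || mediaType.endsWith("+json");
    }

    /** Returns true if the request body should be inspected for XSS payloads. */
    static boolean shouldInspectBody(String method, String contentType) {
        return method != null
                && BODY_METHODS.contains(method.toUpperCase(Locale.ROOT))
                && isJson(contentType);
    }

    public static void main(String[] args) {
        System.out.println(shouldInspectBody("POST", "Application/JSON; charset=UTF-8")); // true
        System.out.println(shouldInspectBody("PUT", "application/problem+json"));         // true
        System.out.println(shouldInspectBody("POST", "text/plain"));                      // false
        System.out.println(shouldInspectBody("GET", "application/json"));                 // false
    }
}
```

Whether PUT and PATCH belong in the list depends on which downstream services actually accept JSON bodies on those methods.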
XssCleanRuleUtils

import org.springframework.util.StringUtils;

import java.util.StringJoiner;
import java.util.regex.Pattern;

public class XssCleanRuleUtils {

    private XssCleanRuleUtils() {

    }

    private static StringJoiner joiner = new StringJoiner("|");
    private static Pattern xssPattern = null;
    private static final String[] xssScriptRegArr = {
            "<script>(.*?)</script>",
            "src[\r\n]*=[\r\n]*\\'(.*?)\\'",
            "</script>",
            "<script(.*?)>",
            "eval\\((.*?)\\)",
            "expression\\((.*?)\\)",
            "javascript:",
            "vbscript:",
            "onload(.*?)=",
            "\\b(and|exec|insert|select|drop|grant|alter|delete|update|count|chr|mid|master|truncate|char|declare|or)\\b|(\\*|;|\\+|'|%)"
    };

    static {
        for (String reg : xssScriptRegArr) {
            joiner.add(reg);
        }

        xssPattern = Pattern.compile(joiner.toString(), Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL);
    }

    public static void main(String[] args) {
        System.out.println(xssMatch("select:567"));
    }

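    /**
     * Checks whether a value looks like an XSS (or SQL injection) payload.
     *
     * @param value the text to check; all whitespace is removed before matching
     * @return true if any pattern in xssScriptRegArr matches
     */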
    public static boolean xssMatch(String value) {
        value = StringUtils.trimAllWhitespace(value);
        if (StringUtils.isEmpty(value)) {
            return false;
        }
        return xssPattern.matcher(value).find();
    }

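    /**
     * Replaces every pattern match with a placeholder, then HTML-escapes any remaining angle brackets.
     *
     * @param value the text to clean
     * @return the cleaned text, or the input unchanged if it is null or empty
     */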
    public static String xssClean(String value) {
        if (StringUtils.isEmpty(value)) {
            return value;
        }

        value = xssPattern.matcher(value).replaceAll("&lt;XssScript&gt;***&lt;/XssScript&gt;");
        value = value.replace("<", "&lt;").replace(">", "&gt;");

        return value;
    }
}
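The last entry in xssScriptRegArr is a SQL-injection blacklist rather than an XSS rule: it matches whole words such as `count`, `update` or `or`, and any single `*`, `;`, `+`, `'` or `%`. Because xssMatch runs the combined pattern over the whole JSON body, ordinary requests can be rejected too (the `main` above already reports `select:567` as an attack). The sketch below applies only that last pattern, with the same flags, to a few legitimate bodies. `SqlKeywordDemo` is a hypothetical name, and `replaceAll("\\s", "")` is an approximation of Spring's `StringUtils.trimAllWhitespace`, used to keep the example dependency-free.

```java
import java.util.regex.Pattern;

/** Hypothetical demo: the SQL-keyword entry of xssScriptRegArr, applied on its own. */
final class SqlKeywordDemo {

    // Copied from the last element of xssScriptRegArr, with the same compile flags
    private static final Pattern SQL_KEYWORDS = Pattern.compile(
            "\\b(and|exec|insert|select|drop|grant|alter|delete|update|count|chr|mid|master|truncate|char|declare|or)\\b"
                    + "|(\\*|;|\\+|'|%)",
            Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL);

    private SqlKeywordDemo() {
    }

    /** Mirrors xssMatch: strip whitespace, then search for any match. */
    static boolean matches(String value) {
        if (value == null) {
            return false;
        }
        String stripped = value.replaceAll("\\s", ""); // approximates StringUtils.trimAllWhitespace
        return !stripped.isEmpty() && SQL_KEYWORDS.matcher(stripped).find();
    }

    public static void main(String[] args) {
        // Legitimate JSON bodies that the filter would answer with 403
        System.out.println(matches("{\"count\": 10}"));              // true: field name "count"
        System.out.println(matches("{\"status\": \"update\"}"));     // true: value "update"
        System.out.println(matches("{\"comment\": \"it's fine\"}")); // true: apostrophe
        System.out.println(matches("{\"discount\": \"5%\"}"));       // true: percent sign
        // A body without any listed keyword or character passes
        System.out.println(matches("{\"name\": \"Alice\"}"));        // false
    }
}
```

If the SQL rule has to stay, running it against individual JSON string values instead of the raw body would at least stop field names from matching, although values like `it's` would still be rejected.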
CommonUtil
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.*;

public class CommonUtil {

    /**
     * Generates a random numeric string of the given length.
     *
     * @param length number of digits
     * @return the random digit string
     */
    public static String getRandomString(int length) {
        String base = "0123456789";
        Random random = new Random();
        StringBuffer sb = new StringBuffer();
        int number = 0;
        for (int i = 0; i < length; i++) {
            number = random.nextInt(base.length());
            sb.append(base.charAt(number));
        }
        return sb.toString();
    }

    /**
     * Gets the client IP address.
     *
     * @param request the incoming request
     * @return the client IP address
     */
    public static String getIpAddr(ServerHttpRequest request) {
        String header1 = "x-forwarded-for";
        String header2 = "Proxy-Client-IP";
        String header3 = "WL-Proxy-Client-IP";
        String ip1 = "unknown";
        String ip2 = "127.0.0.1";
        String ip3 = "0:0:0:0:0:0:0:1";

        HttpHeaders headers = request.getHeaders();
        // getFirst() returns the header value or null; String.valueOf(headers.get(...)) would
        // yield "null" or "[a, b]" instead, so the fallbacks below would never run
        String ipAddress = headers.getFirst(header1);
        if (ipAddress == null || ipAddress.length() == 0 || ip1.equalsIgnoreCase(ipAddress)) {
            ipAddress = headers.getFirst(header2);
        }
        if (ipAddress == null || ipAddress.length() == 0 || ip1.equalsIgnoreCase(ipAddress)) {
            ipAddress = headers.getFirst(header3);
        }
        if ((ipAddress == null || ipAddress.length() == 0 || ip1.equalsIgnoreCase(ipAddress))
                && request.getRemoteAddress() != null) {
            // getHostString() avoids the reverse DNS lookup that getHostName() may perform
            ipAddress = request.getRemoteAddress().getHostString();
            if (ip2.equals(ipAddress) || ip3.equals(ipAddress)) {
                // Fall back to the IP configured on the local network interface
                try {
                    ipAddress = InetAddress.getLocalHost().getHostAddress();
                } catch (UnknownHostException e) {
                    e.printStackTrace();
                }
            }
        }
        // With multiple proxies, the first IP is the client's real IP; multiple IPs are separated by ','
        int i = 15;
        if (ipAddress != null && ipAddress.length() > i) {
            String split = ",";
            if (ipAddress.indexOf(split) > 0) {
                ipAddress = ipAddress.substring(0, ipAddress.indexOf(split));
            }
        }
        return ipAddress;
    }

}
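Reading a header with `String.valueOf(headers.get(...))` is a trap: a missing header becomes the four-character string "null" and a present one becomes "[1.2.3.4]", so neither the null/empty checks nor the "unknown" comparison ever fire. `HttpHeaders.getFirst` returns the first value or `null` instead. The dependency-free sketch below shows the pitfall and one way to pick the client address out of an `X-Forwarded-For` value; `ForwardedFor` and `clientIp` are hypothetical names, and trusting the left-most entry assumes the header is set or sanitised by a proxy you control, since clients can forge it.

```java
import java.util.Collections;
import java.util.List;

/** Hypothetical helper: extracts the client IP from an X-Forwarded-For header value. */
final class ForwardedFor {

    private ForwardedFor() {
    }

    /**
     * Returns the left-most entry that is neither empty nor "unknown",
     * or the fallback (for example the socket's remote address) when there is none.
     */
    static String clientIp(String headerValue, String fallback) {
        if (headerValue != null) {
            for (String part : headerValue.split(",")) {
                String candidate = part.trim();
                if (!candidate.isEmpty() && !"unknown".equalsIgnoreCase(candidate)) {
                    return candidate;
                }
            }
        }
        return fallback;
    }

    public static void main(String[] args) {
        List<String> missing = null;
        // The pitfall: String.valueOf on a header list never yields null or ""
        System.out.println(String.valueOf(missing));                                  // null (a string)
        System.out.println(String.valueOf(Collections.singletonList("203.0.113.7"))); // [203.0.113.7]

        System.out.println(clientIp("203.0.113.7, 10.0.0.2", "10.0.0.1")); // 203.0.113.7
        System.out.println(clientIp(null, "10.0.0.1"));                    // 10.0.0.1
    }
}
```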

Spring Cloud Gateway can implement XSS filtering with a custom filter.

First, create an XSS filter class that implements the `GlobalFilter` and `Ordered` interfaces:

```java
@Component
public class XssGlobalFilter implements GlobalFilter, Ordered {

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        HttpHeaders headers = request.getHeaders();
        MediaType contentType = headers.getContentType();
        HttpMethod method = request.getMethod();
        if (contentType != null && contentType.isCompatibleWith(MediaType.APPLICATION_JSON)
                && HttpMethod.POST.equals(method)) {
            return chain.filter(exchange.mutate().request(new XssServerHttpRequest(request)).build());
        }
        return chain.filter(exchange);
    }

    @Override
    public int getOrder() {
        return -1;
    }
}
```

Here we first check whether the request's Content-Type is `application/json` and whether the method is POST. If so, the request's `ServerHttpRequest` is replaced with a custom `XssServerHttpRequest`, which extends `ServerHttpRequestDecorator` and applies XSS filtering to the request body:

```java
public class XssServerHttpRequest extends ServerHttpRequestDecorator {

    public XssServerHttpRequest(ServerHttpRequest delegate) {
        super(delegate);
    }

    @Override
    public Flux<DataBuffer> getBody() {
        Flux<DataBuffer> body = super.getBody();
        return body.map(dataBuffer -> {
            CharBuffer charBuffer = StandardCharsets.UTF_8.decode(dataBuffer.asByteBuffer());
            String bodyContent = charBuffer.toString();
            // Apply XSS filtering
            String filteredBodyContent = Jsoup.clean(bodyContent, Whitelist.none());
            byte[] bytes = filteredBodyContent.getBytes(StandardCharsets.UTF_8);
            DataBuffer buffer = new DefaultDataBufferFactory().wrap(bytes);
            DataBufferUtils.release(dataBuffer);
            return buffer;
        });
    }
}
```

This class decodes each `DataBuffer` into a `CharBuffer`, converts it to a string, cleans the string with Jsoup, and finally converts the filtered string back into a `DataBuffer`.

Finally, add the filter to the Spring Cloud Gateway filter chain by registering it in a configuration class:

```java
@Configuration
public class GatewayConfig {

    @Bean
    public XssGlobalFilter xssGlobalFilter() {
        return new XssGlobalFilter();
    }

    @Bean
    public RouteLocator customRouteLocator(RouteLocatorBuilder builder) {
        return builder.routes()
                // Add custom routes
                .route(r -> r.path("/api/**").uri("lb://service-provider"))
                .build();
    }
}
```

With this in place, when a request's Content-Type is `application/json` and its method is POST, HTML tags in the request body are stripped, which provides XSS filtering.
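One caveat with the `XssServerHttpRequest` above: `getBody()` maps each `DataBuffer` independently, but a body can arrive in several chunks. Decoding each chunk on its own corrupts any multi-byte UTF-8 character that straddles a chunk boundary, and the cleaner only ever sees fragments of the document. The decorator also forwards the original `Content-Length` header unchanged, which no longer matches whenever cleaning changes the body's size. The JDK-only sketch below simulates a split with two byte arrays; `ChunkSplitDemo` is a hypothetical name, and joining before decoding is what `DataBufferUtils.join` makes possible in WebFlux.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/** Hypothetical demo: decoding UTF-8 chunk by chunk vs. after joining the chunks. */
final class ChunkSplitDemo {

    private ChunkSplitDemo() {
    }

    /** Decodes each chunk on its own, like the per-DataBuffer map() in XssServerHttpRequest. */
    static String decodePerChunk(byte[]... chunks) {
        StringBuilder sb = new StringBuilder();
        for (byte[] chunk : chunks) {
            // Charset.decode replaces malformed sequences with U+FFFD
            sb.append(StandardCharsets.UTF_8.decode(ByteBuffer.wrap(chunk)));
        }
        return sb.toString();
    }

    /** Joins all chunks first, then decodes once. */
    static String decodeJoined(byte[]... chunks) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (byte[] chunk : chunks) {
            out.write(chunk, 0, chunk.length);
        }
        return new String(out.toByteArray(), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        // A JSON body whose value is a two-character Chinese name (3 bytes per character in UTF-8)
        byte[] body = "{\"name\":\"\u5f20\u4e09\"}".getBytes(StandardCharsets.UTF_8);
        // Split after the first byte of the first Chinese character (which starts at index 9)
        byte[] first = Arrays.copyOfRange(body, 0, 10);
        byte[] second = Arrays.copyOfRange(body, 10, body.length);

        System.out.println(decodePerChunk(first, second)); // contains replacement characters
        System.out.println(decodeJoined(first, second));   // the original body, intact
    }
}
```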
