SpringCloud-Gateway实现RSA加解密

Gateway网关作为流量的入口,有些接口需要客户端对请求内容加密后传输(由网关解密),并对返回结果进行加密,以保证数据安全。

一、RSA介绍

        RSA的安全性基于大整数分解这一数学难题,巧妙地利用了数论的概念。给定RSA公钥 $(N, e)$,首先想到的攻击就是分解模数 $N$:一旦得到 $N$ 的因子 $p$ 和 $q$,攻击者即可计算 $\varphi(N) = (p-1)(q-1)$,进而计算出解密指数 $d \equiv e^{-1} \pmod{\varphi(N)}$。我们称这种分解模数的方法为针对RSA的暴力攻击。虽然分解算法一直在稳步改进,但在正确使用RSA的情况下,当前的技术水平仍远未对RSA的安全性构成威胁。

相关讨论:RSA的公钥和私钥到底哪个才是用来加密和哪个用来解密? - 知乎

项目相关依赖:SpringCloud-Gateway实现网关_W_Meng_H的博客-CSDN博客

二、相关工具类

RSA工具类:

import com.example.gateway.common.CommonConstant;
import javax.crypto.Cipher;
import java.io.ByteArrayOutputStream;
import java.security.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.*;

/**
 * @Author: meng
 * @Description: RSA helper: key generation, key parsing, and chunked RSA encryption/decryption
 * with Base64-encoded ciphertext.
 * <p>
 * Plaintext is split into chunks of (modulus bytes - 11) because PKCS#1 v1.5 padding costs
 * 11 bytes per block. Each chunk is encrypted separately and the ciphertext blocks are concatenated.
 * @Date: 2023/4/3 9:27
 * @Version: 1.0
 */
public class RSAUtils {

    /** Charset used to turn plaintext into bytes before encryption and back after decryption. */
    public static final String CHARSET = "UTF-8";

    /** Algorithm name used for KeyFactory / KeyPairGenerator. */
    private final static String ALGORITHM_RSA = "RSA";

    /**
     * Cipher transformation. It is spelled out explicitly because the chunk arithmetic below assumes
     * PKCS#1 v1.5 padding, while the padding chosen for a bare "RSA" is provider-specific.
     */
    private static final String CIPHER_TRANSFORMATION = "RSA/ECB/PKCS1Padding";

    /** Bytes consumed by PKCS#1 v1.5 padding in every encrypted block. */
    private static final int PKCS1_PADDING_OVERHEAD = 11;

    /** Map key of the Base64 public key returned by {@link #getRSAKeyString(int)}. */
    public static final String PUBLIC_KEY = "publicKey";

    /** Map key of the Base64 private key returned by {@link #getRSAKeyString(int)}. */
    public static final String PRIVATE_KEY = "privateKey";

    /**
     * @param modulus key size (modulus length) in bits, e.g. 1024 or 2048
     * @return a mutable list of two keys: index 0 is the RSAPublicKey, index 1 is the RSAPrivateKey
     * @throws NoSuchAlgorithmException if no provider supports RSA
     * @Description: Generates a key pair and returns the key objects directly.
     */
    public static List<Key> getRSAKeyObject(int modulus) throws NoSuchAlgorithmException {
        KeyPair keyPair = generateKeyPair(modulus);
        List<Key> keyList = new ArrayList<>(2);
        keyList.add(keyPair.getPublic());
        keyList.add(keyPair.getPrivate());
        return keyList;
    }

    /**
     * @param modulus key size (modulus length) in bits
     * @return map holding the Base64 X.509 public key under {@link #PUBLIC_KEY} and the Base64
     * PKCS#8 private key under {@link #PRIVATE_KEY}
     * @throws NoSuchAlgorithmException if no provider supports RSA
     * @Description: Generates a key pair and returns both keys as Base64 strings.
     */
    public static Map<String, String> getRSAKeyString(int modulus) throws NoSuchAlgorithmException {
        KeyPair keyPair = generateKeyPair(modulus);
        Map<String, String> keyPairMap = new HashMap<String, String>();
        keyPairMap.put(PUBLIC_KEY, Base64.getEncoder().encodeToString(keyPair.getPublic().getEncoded()));
        keyPairMap.put(PRIVATE_KEY, Base64.getEncoder().encodeToString(keyPair.getPrivate().getEncoded()));
        return keyPairMap;
    }

    /**
     * @param publicKey Base64-encoded X.509 (SubjectPublicKeyInfo) public key
     * @return the parsed public key
     * @throws Exception on invalid Base64 or an invalid key encoding
     * @Description: Parses a public key using X509EncodedKeySpec (RSAPublicKeySpec would also work).
     * @Author: wmh
     * @Date: 2023/4/3 10:15
     */
    public static RSAPublicKey getPublicKey(String publicKey) throws Exception {
        KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_RSA);
        byte[] keyBytes = Base64.getDecoder().decode(publicKey);
        return (RSAPublicKey) keyFactory.generatePublic(new X509EncodedKeySpec(keyBytes));
    }

    /**
     * @param privateKey Base64-encoded PKCS#8 private key
     * @return the parsed private key
     * @throws Exception on invalid Base64 or an invalid key encoding
     * @Description: Parses a private key using PKCS8EncodedKeySpec (RSAPrivateKeySpec would also work).
     * @Author: wmh
     * @Date: 2023/4/3 10:15
     */
    public static RSAPrivateKey getPrivateKey(String privateKey) throws Exception {
        KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_RSA);
        byte[] keyBytes = Base64.getDecoder().decode(privateKey);
        return (RSAPrivateKey) keyFactory.generatePrivate(new PKCS8EncodedKeySpec(keyBytes));
    }

    /**
     * @param data      plaintext, encoded as {@link #CHARSET} before encryption
     * @param publicKey encryption key
     * @return Base64 of the concatenated ciphertext blocks
     * @throws Exception on cipher failures
     * @Description: Public-key encryption.
     */
    public static String encryptByPublicKey(String data, RSAPublicKey publicKey)
            throws Exception {
        return encrypt(data, publicKey, publicKey.getModulus().bitLength());
    }

    /**
     * @param data       Base64 ciphertext produced by {@link #encryptByPublicKey}
     * @param privateKey decryption key
     * @return decrypted plaintext, decoded as {@link #CHARSET}
     * @throws Exception on invalid Base64, wrong key, or corrupted ciphertext
     * @Description: Private-key decryption.
     */
    public static String decryptByPrivateKey(String data, RSAPrivateKey privateKey)
            throws Exception {
        return decrypt(data, privateKey, privateKey.getModulus().bitLength());
    }

    /**
     * @param data       plaintext, encoded as {@link #CHARSET} before encryption
     * @param privateKey encryption key
     * @return Base64 of the concatenated ciphertext blocks
     * @throws Exception on cipher failures
     * @Description: Private-key encryption. Anyone holding the matching PUBLIC key can decrypt the
     * result, so this does not provide confidentiality when the public key is distributed to clients.
     */
    public static String encryptByPrivateKey(String data, RSAPrivateKey privateKey)
            throws Exception {
        return encrypt(data, privateKey, privateKey.getModulus().bitLength());
    }

    /**
     * @param data      Base64 ciphertext produced by {@link #encryptByPrivateKey}
     * @param publicKey decryption key
     * @return decrypted plaintext, decoded as {@link #CHARSET}
     * @throws Exception on invalid Base64, wrong key, or corrupted ciphertext
     * @Description: Public-key decryption.
     */
    public static String decryptByPublicKey(String data, RSAPublicKey publicKey)
            throws Exception {
        return decrypt(data, publicKey, publicKey.getModulus().bitLength());
    }

    /** Creates a fresh RSA key pair of the given modulus size in bits. */
    private static KeyPair generateKeyPair(int modulus) throws NoSuchAlgorithmException {
        KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(ALGORITHM_RSA);
        keyPairGen.initialize(modulus);
        return keyPairGen.generateKeyPair();
    }

    /** Number of bytes needed to hold a modulus of the given bit length (rounded up). */
    private static int modulusByteLength(int modulusBits) {
        return (modulusBits + 7) / 8;
    }

    /** Encrypts UTF-8 plaintext in padding-sized chunks and Base64-encodes the result. */
    private static String encrypt(String data, Key key, int modulusBits) throws Exception {
        Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION);
        cipher.init(Cipher.ENCRYPT_MODE, key);
        int maxChunk = modulusByteLength(modulusBits) - PKCS1_PADDING_OVERHEAD;
        byte[] encrypted = runCipher(cipher, data.getBytes(CHARSET), maxChunk);
        return Base64.getEncoder().encodeToString(encrypted);
    }

    /** Base64-decodes, decrypts block-by-block (one block = modulus bytes), and decodes as UTF-8. */
    private static String decrypt(String data, Key key, int modulusBits) throws Exception {
        Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION);
        cipher.init(Cipher.DECRYPT_MODE, key);
        byte[] cipherBytes = Base64.getDecoder().decode(data.getBytes(CHARSET));
        byte[] plain = runCipher(cipher, cipherBytes, modulusByteLength(modulusBits));
        // Decode with the same charset used for encryption, never the platform default.
        return new String(plain, CHARSET);
    }

    /** Feeds each chunk of {@code input} through {@code cipher.doFinal} and concatenates the outputs. */
    private static byte[] runCipher(Cipher cipher, byte[] input, int chunkSize) throws Exception {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (byte[] chunk : splitArray(input, chunkSize)) {
            out.write(cipher.doFinal(chunk));
        }
        return out.toByteArray();
    }

    /**
     * @param data source bytes
     * @param len  maximum length of each chunk; must be positive
     * @return consecutive chunks of {@code data}. Every chunk except possibly the last has length
     * {@code len}. If {@code data} is not longer than {@code len}, the array itself is the only element.
     * @throws IllegalArgumentException if {@code len <= 0} (e.g. a key too small for PKCS#1 padding)
     * @Description: Splits an array into chunks of the given length.
     */
    private static byte[][] splitArray(byte[] data, int len) {
        if (len <= 0) {
            throw new IllegalArgumentException("chunk length must be positive: " + len);
        }
        int dataLen = data.length;
        if (dataLen <= len) {
            return new byte[][]{data};
        }
        byte[][] result = new byte[(dataLen - 1) / len + 1][];
        for (int i = 0; i < result.length; i++) {
            int from = len * i;
            result[i] = Arrays.copyOfRange(data, from, Math.min(from + len, dataLen));
        }
        return result;
    }

    /** Demo: round-trips a JSON message through the configured key pair. */
    public static void main(String[] args) throws Exception {
        String pukString = CommonConstant.RSA_PUBLIC_KEY;
        String prkString = CommonConstant.RSA_PRIVATE_KEY;
        System.out.println("公钥:" + pukString);
        System.out.println("私钥:" + prkString);
        // Parse the public and private keys
        RSAPublicKey puk = RSAUtils.getPublicKey(pukString);
        RSAPrivateKey prk = RSAUtils.getPrivateKey(prkString);
        String message = "{\"message\":\"test\",\"name\":\"test\"}";
        String encryptedMsg = RSAUtils.encryptByPublicKey(message, puk);
        String decryptedMsg = RSAUtils.decryptByPrivateKey(encryptedMsg, prk);
        System.out.println("未加密内容:" + message);
        System.out.println("公钥加密内容:" + encryptedMsg);
        System.out.println("私钥解密内容:" + decryptedMsg);
        System.out.println("-----------------------------------");
    }

}

常用变量类:

import com.alibaba.fastjson.JSONObject;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

/**
 * @Author: meng
 * @Description: Shared constants and a common JSON error-response helper.
 * @Date: 2023/3/30 10:29
 * @Version: 1.0
 */
@Component
public class CommonConstant {

    // JWT secret key
    public static final String JWT_TOKEN = "jwt-token";

    // Request header carrying the token
    public static final String X_TOKEN = "X-TOKEN";

    // Request header carrying the signature
    public static final String X_SIGN = "X-SIGN";

    public static final String X_APPID = "X-APPID";

    public static final String CODE = "code";

    public static final String MESSAGE = "message";

    public static final String UTF8 = "UTF-8";

    // SECURITY NOTE(review): a private key committed to source code is exposed to everyone with repo
    // access. Consider loading both keys from external configuration or a secret store instead.
    public static final String RSA_PUBLIC_KEY = "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDFJIl4il6nDBlF/3byWB/KXRqfEXkviz7ZvO7TU7JBfh7sFqfgLtJFDSA33+qTHOtYTCjCrwl6oWWX7Aff39HiFW1IBnhKjYdSK5/8ruQY+Y2xbpBMgslA0m2euOv3XPJUXWh0JGBqPllgzvtbtUA1iBELAHVYBACuQPYP2VcPeQIDAQAB";
    public static final String RSA_PRIVATE_KEY = "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAMUkiXiKXqcMGUX/dvJYH8pdGp8ReS+LPtm87tNTskF+HuwWp+Au0kUNIDff6pMc61hMKMKvCXqhZZfsB9/f0eIVbUgGeEqNh1Irn/yu5Bj5jbFukEyCyUDSbZ646/dc8lRdaHQkYGo+WWDO+1u1QDWIEQsAdVgEAK5A9g/ZVw95AgMBAAECgYABvRrBR2ciTgcDCQfBh2lwXXXYpUzOUIoTXYk1r+1IipY3OtPsND2CgmUgWQc2mPCybKmHXgfVXwsIVfqTzOOK+PEMVGYNflUdXgV3hNffRzl/nfPdpqhb2ALu8ftPwiGq5QN2PqaRgY9kM67Ye/cCjFzm/kLIqsNuXLKiQc1ioQJBAO7g4ZBcG/D0IxtiR4RdXYtr4wQc+cmscSKj5RPNBwn0bh9psOSg2loS/wWUmCnYSncsLGgMzPl+yPkTLwGryH0CQQDTRduiOzu6bFdOw6tI6eOxHB5h0kfcim4VT/Huh5RyP+GC7kLBmknbBO/tQXxSDVaG81Pkr+INHxJmctfKik+tAkEAtBIrl0IIAhRXnp3wYXRsPtxeLkyVc5SdWEqKNen5Y2Sx2tY2dbJXx0zIl3FTXz/fqoRPGUSFA5Kydygh6DWRlQJBAMmOfOHB9tJ8Z7LJ85AFKucdt1KFpW8eVbVZZqq0iAeTMBaULfW7tzgO9sJ3Vh6FgQYP//pNXbA883XvnDUrTKUCQQDgLO7mThmy7iqqo0be4a2ycy9fvORFYzSq1t6mTd+gr73CMCy2bTmyv/Qp4QsuPIKea0iE+HA/la5zlM8eAxOq";

    /**
     * Writes {@code {"code": code, "message": message}} as a JSON body and completes the exchange.
     * The HTTP status is always 200; the business status travels in {@code code}.
     *
     * @param exchange current exchange whose response is written
     * @param code     business status code placed in the body
     * @param message  message placed in the body
     * @return completion signal of the response write
     */
    public static Mono<Void> buildResponse(ServerWebExchange exchange, int code, String message) {
        ServerHttpResponse response = exchange.getResponse();
        response.setStatusCode(HttpStatus.OK);
        // set (not add) so an earlier Content-Type is replaced rather than duplicated
        response.getHeaders().set(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_UTF8_VALUE);
        JSONObject jsonObject = new JSONObject();
        jsonObject.put(CODE, code);
        jsonObject.put(MESSAGE, message);
        // Encode explicitly as UTF-8 to match the declared charset (getBytes() uses the platform default)
        byte[] bodyBytes = jsonObject.toJSONString().getBytes(java.nio.charset.StandardCharsets.UTF_8);
        DataBuffer bodyDataBuffer = response.bufferFactory().wrap(bodyBytes);
        return response.writeWith(Mono.just(bodyDataBuffer));
    }
}

PS:要实现对请求参数解密,需要先解决请求body只能读取一次的问题。网上已有很多解决方案,大家可以自行搜索。

推荐一个:SpringCloud-Gateway获取body参数,解决只能获取一次问题,终极解决方案_kamjin1996的博客-CSDN博客

不想看上边的博客,直接使用如下代码即可:

import com.example.gateway.common.GatewayContext;
import io.netty.buffer.ByteBufAllocator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

/**
 * @Author: https://blog.csdn.net/zx156955/article/details/115004910
 * @Description: Caches request data in a GatewayContext stored in the exchange attributes, so the
 * request body can be read by later filters and still be forwarded downstream.
 * @Date: 2023/3/30 10:11
 * @Version: 1.0
 */
@Component
@Slf4j
public class RequestCoverFilter implements GlobalFilter, Ordered {

    /**
     * Default HttpMessageReaders, used to decode the cached JSON body into a String.
     */
    private static final List<HttpMessageReader<?>> messageReaders = HandlerStrategies.withDefaults().messageReaders();

    /**
     * Reads an application/x-www-form-urlencoded body, stores it in the gateway context, and
     * forwards a request whose body is the re-encoded form data (the original body stream can only be
     * consumed once).
     *
     * @param exchange       current exchange
     * @param chain          remaining filter chain
     * @param gatewayContext per-request context to populate
     * @return completion of the downstream chain
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, GatewayFilterChain chain,
                                    GatewayContext gatewayContext) {
        final ServerHttpRequest request = exchange.getRequest();
        HttpHeaders headers = request.getHeaders();

        return exchange.getFormData().doOnNext(multiValueMap -> {
            gatewayContext.setFormData(multiValueMap);
            log.debug("[GatewayContext]Read FormData:{}", multiValueMap);
        }).then(Mono.defer(() -> {
            // Only called for form content types, so getContentType() is non-null here.
            Charset charset = headers.getContentType().getCharset();
            charset = charset == null ? StandardCharsets.UTF_8 : charset;
            MultiValueMap<String, String> formData = gatewayContext.getFormData();
            // formData is empty: just continue
            if (null == formData || formData.isEmpty()) {
                return chain.filter(exchange);
            }
            String formDataBodyString = encodeFormData(formData, charset.name());
            byte[] bodyBytes = formDataBodyString.getBytes(charset);
            int contentLength = bodyBytes.length;
            ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(request) {
                /**
                 * Reports the length of the re-encoded body instead of the original one.
                 */
                @Override
                public HttpHeaders getHeaders() {
                    HttpHeaders httpHeaders = new HttpHeaders();
                    httpHeaders.putAll(super.getHeaders());
                    if (contentLength > 0) {
                        httpHeaders.setContentLength(contentLength);
                    } else {
                        // Content-Length and Transfer-Encoding must not be sent together.
                        httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
                        httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                    }
                    return httpHeaders;
                }

                /**
                 * Replays the re-encoded bytes as the request body on every subscription.
                 */
                @Override
                public Flux<DataBuffer> getBody() {
                    if (contentLength == 0) {
                        // DataBufferUtils.read rejects a buffer size of 0
                        return Flux.empty();
                    }
                    return DataBufferUtils.read(new ByteArrayResource(bodyBytes),
                            new NettyDataBufferFactory(ByteBufAllocator.DEFAULT), contentLength);
                }
            };
            ServerWebExchange mutateExchange = exchange.mutate().request(decorator).build();
            log.info("[GatewayContext]Rewrite Form Data :{}", formDataBodyString);

            return chain.filter(mutateExchange);
        }));
    }

    /**
     * Serializes form data as {@code k1=v1&k1=v2&k2=v3}, URL-encoding keys and values with the given
     * charset. A {@code null} value (a key sent without '=') is written as the bare key.
     */
    private static String encodeFormData(MultiValueMap<String, String> formData, String charsetName) {
        StringBuilder builder = new StringBuilder();
        try {
            for (Map.Entry<String, List<String>> entry : formData.entrySet()) {
                String encodedKey = URLEncoder.encode(entry.getKey(), charsetName);
                for (String value : entry.getValue()) {
                    if (builder.length() > 0) {
                        builder.append('&');
                    }
                    builder.append(encodedKey);
                    if (value != null) {
                        builder.append('=').append(URLEncoder.encode(value, charsetName));
                    }
                }
            }
        } catch (UnsupportedEncodingException e) {
            // charsetName comes from a resolved Charset instance, so this is not expected to happen.
            throw new IllegalStateException("Unsupported charset: " + charsetName, e);
        }
        return builder.toString();
    }

    /**
     * Reads a JSON body, caches it as a String in the gateway context, and forwards a request whose
     * body replays the cached bytes.
     *
     * @param exchange       current exchange
     * @param chain          remaining filter chain
     * @param gatewayContext per-request context to populate
     * @return completion of the downstream chain
     */
    private Mono<Void> readBody(ServerWebExchange exchange, GatewayFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody()).flatMap(dataBuffer -> {
            /*
             * Copy the joined body out and release the buffer.
             * see PR https://github.com/spring-cloud/spring-cloud-gateway/pull/1095
             */
            byte[] bytes = new byte[dataBuffer.readableByteCount()];
            dataBuffer.read(bytes);
            DataBufferUtils.release(dataBuffer);
            // Each subscription gets a fresh buffer over the same bytes, so the body can be re-read.
            Flux<DataBuffer> cachedFlux = Flux.defer(() -> {
                DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
                DataBufferUtils.retain(buffer);
                return Mono.just(buffer);
            });
            ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
                @Override
                public Flux<DataBuffer> getBody() {
                    return cachedFlux;
                }
            };
            ServerWebExchange mutatedExchange = exchange.mutate().request(mutatedRequest).build();
            // Decode the body with the default message readers, then continue the chain.
            return ServerRequest.create(mutatedExchange, messageReaders).bodyToMono(String.class)
                    .doOnNext(objectValue -> {
                        gatewayContext.setCacheBody(objectValue);
                        log.debug("[GatewayContext]Read JsonBody:{}", objectValue);
                    }).then(Mono.defer(() -> chain.filter(mutatedExchange)));
        });
    }

    @Override
    public int getOrder() {
        return HIGHEST_PRECEDENCE;
    }

    /**
     * Creates the GatewayContext, stores it in the exchange attributes before any asynchronous work
     * (so other filters can always find it), then caches the body for JSON and form requests.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();

        GatewayContext gatewayContext = new GatewayContext();
        String path = request.getPath().pathWithinApplication().value();
        gatewayContext.setPath(path);
        gatewayContext.getFormData().addAll(request.getQueryParams());
        gatewayContext.setIpAddress(String.valueOf(request.getRemoteAddress()));
        HttpHeaders headers = request.getHeaders();
        gatewayContext.setHeaders(headers);
        log.debug("HttpMethod:{},Url:{}", request.getMethod(), request.getURI().getRawPath());

        // With reactive processing the context must be put into the exchange up front,
        // otherwise other filters may not find it.
        exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);

        MediaType contentType = headers.getContentType();
        long contentLength = headers.getContentLength();
        if (contentLength > 0) {
            // includes() ignores parameters such as charset, so e.g.
            // "application/x-www-form-urlencoded;charset=UTF-8" is matched too; it returns false for null.
            if (MediaType.APPLICATION_JSON.includes(contentType)) {
                return readBody(exchange, chain, gatewayContext);
            }
            if (MediaType.APPLICATION_FORM_URLENCODED.includes(contentType)) {
                return readFormData(exchange, chain, gatewayContext);
            }
        }

        log.debug("[GatewayContext]ContentType:{},Gateway context is set with {}", contentType, gatewayContext);
        return chain.filter(exchange);
    }
}
网关上下文:
import lombok.Data;
import org.springframework.http.HttpHeaders;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

/**
 * @Author: meng
 * @Description: Gateway context - a per-request holder for request data (headers, cached body,
 * form data, client address, path). RequestCoverFilter creates one per request and stores it in
 * the exchange attributes under {@link #CACHE_GATEWAY_CONTEXT}; later filters read it from there.
 * Lombok {@code @Data} generates the getters, setters, equals/hashCode and toString from the fields.
 * @Version: 1.0
 */
@Data
public class GatewayContext {

	/**
	 * Exchange attribute key under which the context is stored.
	 */
	public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";

	/**
	 * Cached request headers.
	 */
	private HttpHeaders headers;

	/**
	 * Cached request body as a String (JSON bodies). Filters may overwrite it, e.g. with the
	 * decrypted plaintext.
	 */
	private String cacheBody;

	/**
	 * Form data. Starts as the query parameters and is replaced by the parsed body for
	 * form-urlencoded requests.
	 */
	private MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();

	/**
	 * Client address as text (String.valueOf of the remote address, so it may be "null").
	 */
	private String ipAddress;

	/**
	 * Request path within the application.
	 */
	private String path;

}

三、 RSA实现对请求参数解密

1、修改yml文件,增加解密过滤器:在需要解密的路由的 filters 中添加 `- RSADecryptResponse`(过滤器名称即工厂类名去掉 `GatewayFilterFactory` 后缀)。

 2、RSA解密过滤器

import cn.hutool.core.util.StrUtil;
import com.example.gateway.common.CommonConstant;
import com.example.gateway.common.GatewayContext;
import com.example.gateway.utils.RSAUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

import java.security.interfaces.RSAPrivateKey;

/**
 * @Author: meng
 * @Description: RSA实现对请求参数解密
 * @Date: 2023/4/6 15:20
 * @Version: 1.0
 */
@Slf4j
@Component
public class RSADecryptResponseGatewayFilterFactory extends AbstractGatewayFilterFactory {

    @Override
    public GatewayFilter apply(Object config) {
        return (exchange, chain) -> {
            ServerHttpRequest serverHttpRequest = exchange.getRequest();
            HttpHeaders header = serverHttpRequest.getHeaders();
            String decrypt = serverHttpRequest.getHeaders().getFirst("decrypt");
            if (!HttpMethod.POST.matches(serverHttpRequest.getMethodValue())) {
                return chain.filter(exchange);
            }
            byte[] decrypBytes;
            GatewayContext gatewayContext = exchange.getAttribute(GatewayContext.CACHE_GATEWAY_CONTEXT);
            if(StrUtil.isBlank(gatewayContext.getCacheBody())){
                return CommonConstant.buildResponse(exchange, HttpStatus.BAD_REQUEST.value(), "请求参数不能为空");
            }
            try {
                // 获取request body
                String requestBody = gatewayContext.getCacheBody();
                log.info("encryptMsg body :{}", requestBody);
                RSAPrivateKey privateKey = RSAUtils.getPrivateKey(CommonConstant.RSA_PRIVATE_KEY);
                String decryptMsg = RSAUtils.decryptByPrivateKey(requestBody, privateKey);
                log.info("decryptMsg body :{}", decryptMsg);
                gatewayContext.setCacheBody(decryptMsg);
                decrypBytes = decryptMsg.getBytes();
            } catch (Exception e) {
                log.error("RSA 解密失败:{}", e);
                return CommonConstant.buildResponse(exchange, HttpStatus.BAD_REQUEST.value(), "RSA解密失败");
            }
            // 根据解密后的参数重新构建请求
            DataBufferFactory dataBufferFactory = exchange.getResponse().bufferFactory();
            Flux<DataBuffer> bodyFlux = Flux.just(dataBufferFactory.wrap(decrypBytes));
            ServerHttpRequest newRequest = serverHttpRequest.mutate().uri(serverHttpRequest.getURI()).build();
            newRequest = new ServerHttpRequestDecorator(newRequest) {
                @Override
                public Flux<DataBuffer> getBody() {
                    return bodyFlux;
                }
            };

            // 构建新的请求头
            HttpHeaders headers = new HttpHeaders();
            headers.putAll(exchange.getRequest().getHeaders());
            // 由于修改了传递参数,需要重新设置CONTENT_LENGTH,长度是字节长度,不是字符串长度
            int length = decrypBytes.length;
            headers.remove(HttpHeaders.CONTENT_LENGTH);
            headers.setContentLength(length);
            newRequest = new ServerHttpRequestDecorator(newRequest) {
                @Override
                public HttpHeaders getHeaders() {
                    return headers;
                }
            };

            // 把解密后的数据重置到exchange自定义属性中,在之后的日志GlobalLogFilter从此处获取请求参数打印日志
            exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);
            return chain.filter(exchange.mutate().request(newRequest).build());
        };
    }

}

 3、测试结果:

四、 RSA实现对返回结果加密

最简单的方法是使用gateway全局过滤器实现;但并非所有接口的返回结果都需要加密,此时可以改用局部过滤器实现。

实践过程中遇到一个问题:ServerHttpResponseDecorator不生效,原因在于过滤器的执行顺序。

全局过滤器:

自定义GlobalFilter的order必须小于-1(即NettyWriteResponseFilter的order),否则标准的NettyWriteResponseFilter会在自定义过滤器装饰响应之前就把响应写出。

相关讨论:https://github.com/spring-cloud/spring-cloud-gateway/issues/47

局部过滤器:

官方提供了一种方式,详细介绍:Spring Cloud Gateway

如果想使用yml+nacos实现动态路由,可以借鉴原生类ModifyResponseBodyGatewayFilterFactory实现。

1、修改yml文件,增加加密过滤器:在需要加密返回结果的路由的 filters 中添加 `- RSAEncryptResponse`(过滤器名称即工厂类名去掉 `GatewayFilterFactory` 后缀)。

 2、RSA加密过滤器

import cn.hutool.core.util.StrUtil;
import com.example.gateway.common.CommonConstant;
import com.example.gateway.utils.RSAUtils;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.nio.charset.Charset;
import java.security.interfaces.RSAPrivateKey;

/**
 * @Author: meng
 * @Description: RSA加密 - 借鉴原生类ModifyRequestBodyGatewayFilterFactory实现
 * https://docs.spring.io/spring-cloud-gateway/docs/current/reference/html/#the-modifyresponsebody-gatewayfilter-factory
 * @Date: 2023/4/3 11:37
 * @Version: 1.0
 */
@Slf4j
@Component
public class RSAEncryptResponseGatewayFilterFactory extends AbstractGatewayFilterFactory {

	@Override
	public GatewayFilter apply(Object config) {
		RewriteResponseGatewayFilter rewriteResponseGatewayFilter = new RewriteResponseGatewayFilter();
		return rewriteResponseGatewayFilter;
	}

	public class RewriteResponseGatewayFilter implements GatewayFilter, Ordered {

		@Override
		public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
			if (!HttpMethod.POST.matches(exchange.getRequest().getMethodValue())) {
				return chain.filter(exchange);
			}
			ServerHttpResponse originalResponse = exchange.getResponse();
			DataBufferFactory bufferFactory = originalResponse.bufferFactory();
			ServerHttpResponseDecorator decoratedResponse = new ServerHttpResponseDecorator(originalResponse) {
				@Override
				public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
					if (body instanceof Flux) {
						Flux<? extends DataBuffer> fluxBody = (Flux<? extends DataBuffer>) body;
						return super.writeWith(fluxBody.map(dataBuffer -> {
							byte[] content = new byte[dataBuffer.readableByteCount()];
							dataBuffer.read(content);
							DataBufferUtils.release(dataBuffer);
							String responseStr = new String(content, Charset.forName(CommonConstant.UTF8));
							//RSA加密response值
							log.info("RSA response:{}", responseStr);
							byte[] newContent = new byte[0];
							try {
								RSAPrivateKey privateKey = RSAUtils.getPrivateKey(CommonConstant.RSA_PRIVATE_KEY);
								String encryptMsg = RSAUtils.encryptByPrivateKey(responseStr, privateKey);
								log.info("RSA encryptMsg:{}", encryptMsg);
								newContent = encryptMsg.getBytes(CommonConstant.UTF8);
							} catch (Exception e) {
								log.error("RSA fail:{}", e);
								throw new RuntimeException(e);
							}
							return bufferFactory.wrap(newContent);
						}));
					}
					return super.writeWith(body);
				}
			};
			return chain.filter(exchange.mutate().response(decoratedResponse).build());
		}

		@Override
		public int getOrder() {
			return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 1;
		}

	}
}

3、测试结果

 

评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值