Gateway网关作为流量的入口,有的接口可能需要对请求内容加密,返回结果加密,保证数据安全性。
一、RSA介绍
RSA主要使用大整数分解这个数学难题进行设计,巧妙地利用了数论的概念。给了RSA公钥 $(n, e)$,首先想到的攻击就是分解模数 $n$:一旦得到 $n$ 的因子 $p$、$q$,攻击者就可以计算得到 $\varphi(n) = (p-1)(q-1)$,从而也可以计算得到解密指数 $d \equiv e^{-1} \pmod{\varphi(n)}$,我们称这种分解模数的方法为针对RSA的暴力攻击。虽然分解算法已经稳步改进,但是在正确使用RSA的情况下,当前的技术水平仍远未对RSA的安全性构成威胁。
相关讨论:RSA的公钥和私钥到底哪个才是用来加密和哪个用来解密? - 知乎
项目相关依赖:SpringCloud-Gateway实现网关_W_Meng_H的博客-CSDN博客
二、相关工具类
RSA工具类:
import com.example.gateway.common.CommonConstant;
import javax.crypto.Cipher;
import java.io.ByteArrayOutputStream;
import java.security.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.*;
/**
 * RSA helper: key-pair generation, Base64 key (de)serialization, and chunked
 * encryption/decryption using PKCS#1 v1.5 padding.
 *
 * <p>Payloads longer than one RSA block are split into chunks, each chunk is
 * processed independently, and the concatenated result is Base64-encoded.
 *
 * @Author: meng
 * @Description: RSA encryption utility class
 * @Date: 2023/4/3 9:27
 * @Version: 1.0
 */
public class RSAUtils {
    /** Charset used for every String <-> byte[] conversion in this class. */
    public static final String CHARSET = "UTF-8";
    // Key algorithm name used for KeyFactory / KeyPairGenerator.
    private final static String ALGORITHM_RSA = "RSA";
    // Explicit cipher transformation. The chunk-size arithmetic below assumes
    // PKCS#1 v1.5 padding, so it must not depend on the provider's default for "RSA".
    private static final String TRANSFORMATION_RSA = "RSA/ECB/PKCS1Padding";
    // Minimum number of bytes PKCS#1 v1.5 padding consumes in every block.
    private static final int PKCS1_PADDING_OVERHEAD = 11;
    /** Map key under which {@link #getRSAKeyString(int)} stores the public key. */
    public static final String PUBLIC_KEY = "publicKey";
    /** Map key under which {@link #getRSAKeyString(int)} stores the private key. */
    public static final String PRIVATE_KEY = "privateKey";

    /**
     * Generates an RSA key pair and returns the key objects.
     *
     * @param modulus modulus length in bits
     * @return a two-element list: index 0 is the public key, index 1 the private key
     * @throws NoSuchAlgorithmException if no provider supports RSA
     */
    public static List<Key> getRSAKeyObject(int modulus) throws NoSuchAlgorithmException {
        List<Key> keyList = new ArrayList<>(2);
        KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(ALGORITHM_RSA);
        // The key size is the modulus length; it bounds how much data fits in one block.
        keyPairGen.initialize(modulus);
        KeyPair keyPair = keyPairGen.generateKeyPair();
        keyList.add(keyPair.getPublic());
        keyList.add(keyPair.getPrivate());
        return keyList;
    }

    /**
     * Generates an RSA key pair and returns both keys as Base64 strings of their
     * encoded form.
     *
     * @param modulus modulus length in bits
     * @return map with entries {@link #PUBLIC_KEY} and {@link #PRIVATE_KEY}
     * @throws NoSuchAlgorithmException if no provider supports RSA
     */
    public static Map<String, String> getRSAKeyString(int modulus) throws NoSuchAlgorithmException {
        Map<String, String> keyPairMap = new HashMap<>();
        KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(ALGORITHM_RSA);
        keyPairGen.initialize(modulus);
        KeyPair keyPair = keyPairGen.generateKeyPair();
        String publicKey = Base64.getEncoder().encodeToString(keyPair.getPublic().getEncoded());
        String privateKey = Base64.getEncoder().encodeToString(keyPair.getPrivate().getEncoded());
        keyPairMap.put(PUBLIC_KEY, publicKey);
        keyPairMap.put(PRIVATE_KEY, privateKey);
        return keyPairMap;
    }

    /**
     * Rebuilds an RSA public key from its Base64 string using
     * {@link X509EncodedKeySpec}.
     *
     * @param publicKey Base64 text of the encoded public key
     * @return the public key object
     * @throws Exception if the text is not valid Base64 or not a valid RSA public key
     * @Author: wmh
     * @Date: 2023/4/3 10:15
     */
    public static RSAPublicKey getPublicKey(String publicKey) throws Exception {
        KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_RSA);
        byte[] keyBytes = Base64.getDecoder().decode(publicKey);
        X509EncodedKeySpec spec = new X509EncodedKeySpec(keyBytes);
        return (RSAPublicKey) keyFactory.generatePublic(spec);
    }

    /**
     * Rebuilds an RSA private key from its Base64 string using
     * {@link PKCS8EncodedKeySpec}.
     *
     * @param privateKey Base64 text of the encoded private key
     * @return the private key object
     * @throws Exception if the text is not valid Base64 or not a valid RSA private key
     * @Author: wmh
     * @Date: 2023/4/3 10:15
     */
    public static RSAPrivateKey getPrivateKey(String privateKey) throws Exception {
        KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_RSA);
        byte[] keyBytes = Base64.getDecoder().decode(privateKey);
        PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
        return (RSAPrivateKey) keyFactory.generatePrivate(spec);
    }

    /**
     * Encrypts text with the public key.
     *
     * @param data      plain text; encoded as UTF-8 before encryption
     * @param publicKey key to encrypt with
     * @return Base64 text of the concatenated cipher blocks
     * @throws Exception if the cipher cannot be created or a block fails to encrypt
     */
    public static String encryptByPublicKey(String data, RSAPublicKey publicKey)
            throws Exception {
        return encrypt(data, publicKey, byteLength(publicKey.getModulus().bitLength()));
    }

    /**
     * Decrypts text that was produced by {@link #encryptByPublicKey}.
     *
     * @param data       Base64 text of the concatenated cipher blocks
     * @param privateKey key to decrypt with
     * @return the decrypted text, decoded as UTF-8
     * @throws Exception if the input is not valid Base64 or a block fails to decrypt
     */
    public static String decryptByPrivateKey(String data, RSAPrivateKey privateKey)
            throws Exception {
        return decrypt(data, privateKey, byteLength(privateKey.getModulus().bitLength()));
    }

    /**
     * Encrypts text with the private key.
     *
     * <p>NOTE(review): anyone holding the public key can reverse this, so it does
     * not provide confidentiality.
     *
     * @param data       plain text; encoded as UTF-8 before encryption
     * @param privateKey key to encrypt with
     * @return Base64 text of the concatenated cipher blocks
     * @throws Exception if the cipher cannot be created or a block fails to encrypt
     */
    public static String encryptByPrivateKey(String data, RSAPrivateKey privateKey)
            throws Exception {
        return encrypt(data, privateKey, byteLength(privateKey.getModulus().bitLength()));
    }

    /**
     * Decrypts text that was produced by {@link #encryptByPrivateKey}.
     *
     * @param data      Base64 text of the concatenated cipher blocks
     * @param publicKey key to decrypt with
     * @return the decrypted text, decoded as UTF-8
     * @throws Exception if the input is not valid Base64 or a block fails to decrypt
     */
    public static String decryptByPublicKey(String data, RSAPublicKey publicKey)
            throws Exception {
        return decrypt(data, publicKey, byteLength(publicKey.getModulus().bitLength()));
    }

    /** Converts a modulus bit length to bytes, rounding up so odd sizes are not truncated. */
    private static int byteLength(int bitLength) {
        return (bitLength + 7) / 8;
    }

    /** Encrypts {@code data} chunk by chunk and Base64-encodes the concatenated blocks. */
    private static String encrypt(String data, Key key, int modulusBytes) throws Exception {
        Cipher cipher = Cipher.getInstance(TRANSFORMATION_RSA);
        cipher.init(Cipher.ENCRYPT_MODE, key);
        // Each block must leave room for the padding bytes.
        int maxSingleSize = modulusBytes - PKCS1_PADDING_OVERHEAD;
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (byte[] chunk : splitArray(data.getBytes(CHARSET), maxSingleSize)) {
            out.write(cipher.doFinal(chunk));
        }
        return Base64.getEncoder().encodeToString(out.toByteArray());
    }

    /** Base64-decodes {@code data}, decrypts it block by block and decodes the result as UTF-8. */
    private static String decrypt(String data, Key key, int modulusBytes) throws Exception {
        Cipher cipher = Cipher.getInstance(TRANSFORMATION_RSA);
        cipher.init(Cipher.DECRYPT_MODE, key);
        byte[] decodeData = Base64.getDecoder().decode(data.getBytes(CHARSET));
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        // Every cipher block is exactly one modulus length long.
        for (byte[] block : splitArray(decodeData, modulusBytes)) {
            out.write(cipher.doFinal(block));
        }
        // Decode with the same charset used when encrypting, not the platform default.
        return new String(out.toByteArray(), CHARSET);
    }

    /**
     * Splits an array into consecutive pieces of at most {@code len} bytes.
     *
     * @param data bytes to split; returned as the single piece when it already fits
     * @param len  maximum length of one piece, must be positive
     * @return the pieces in order; the last one may be shorter
     */
    private static byte[][] splitArray(byte[] data, int len) {
        if (len <= 0) {
            throw new IllegalArgumentException("chunk length must be positive: " + len);
        }
        int dataLen = data.length;
        if (dataLen <= len) {
            return new byte[][]{data};
        }
        byte[][] result = new byte[(dataLen - 1) / len + 1][];
        for (int i = 0; i < result.length; i++) {
            int from = len * i;
            result[i] = Arrays.copyOfRange(data, from, Math.min(from + len, dataLen));
        }
        return result;
    }

    /**
     * Demo: generates a throwaway key pair, then encrypts and decrypts a sample
     * message with it and prints each step.
     */
    public static void main(String[] args) throws Exception {
        // 1024 bits keeps the demo fast.
        Map<String, String> keyStringList = RSAUtils.getRSAKeyString(1024);
        String pukString = keyStringList.get(PUBLIC_KEY);
        String prkString = keyStringList.get(PRIVATE_KEY);
        System.out.println("公钥:" + pukString);
        System.out.println("私钥:" + prkString);
        // Rebuild key objects from their string form.
        RSAPublicKey puk = RSAUtils.getPublicKey(pukString);
        RSAPrivateKey prk = RSAUtils.getPrivateKey(prkString);
        String message = "{\"message\":\"test\",\"name\":\"test\"}";
        String encryptedMsg = RSAUtils.encryptByPublicKey(message, puk);
        String decryptedMsg = RSAUtils.decryptByPrivateKey(encryptedMsg, prk);
        System.out.println("未加密内容:" + message);
        System.out.println("公钥加密内容:" + encryptedMsg);
        System.out.println("私钥解密内容:" + decryptedMsg);
        System.out.println("-----------------------------------");
    }
}
常用变量类:
import com.alibaba.fastjson.JSONObject;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/**
 * Shared constants (header names, JSON field names, RSA key material) and a
 * helper for writing a JSON {@code {code, message}} response from the gateway.
 *
 * @Author: meng
 * @Description: common constants
 * @Date: 2023/3/30 10:29
 * @Version: 1.0
 */
@Component
public class CommonConstant {
    // Labelled "JWT secret" by the original author.
    public static final String JWT_TOKEN = "jwt-token";
    // Request header carrying the token.
    public static final String X_TOKEN = "X-TOKEN";
    // Request header carrying the signature.
    public static final String X_SIGN = "X-SIGN";
    // Request header carrying the application id.
    public static final String X_APPID = "X-APPID";
    // JSON field names used by buildResponse.
    public static final String CODE = "code";
    public static final String MESSAGE = "message";
    public static final String UTF8 = "UTF-8";
    // Base64 text of the RSA key pair.
    // NOTE(review): a private key committed to source is readable by anyone with repository
    // access; consider loading it from configuration or a secret store instead.
    public static final String RSA_PUBLIC_KEY = "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDFJIl4il6nDBlF/3byWB/KXRqfEXkviz7ZvO7TU7JBfh7sFqfgLtJFDSA33+qTHOtYTCjCrwl6oWWX7Aff39HiFW1IBnhKjYdSK5/8ruQY+Y2xbpBMgslA0m2euOv3XPJUXWh0JGBqPllgzvtbtUA1iBELAHVYBACuQPYP2VcPeQIDAQAB";
    public static final String RSA_PRIVATE_KEY = "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAMUkiXiKXqcMGUX/dvJYH8pdGp8ReS+LPtm87tNTskF+HuwWp+Au0kUNIDff6pMc61hMKMKvCXqhZZfsB9/f0eIVbUgGeEqNh1Irn/yu5Bj5jbFukEyCyUDSbZ646/dc8lRdaHQkYGo+WWDO+1u1QDWIEQsAdVgEAK5A9g/ZVw95AgMBAAECgYABvRrBR2ciTgcDCQfBh2lwXXXYpUzOUIoTXYk1r+1IipY3OtPsND2CgmUgWQc2mPCybKmHXgfVXwsIVfqTzOOK+PEMVGYNflUdXgV3hNffRzl/nfPdpqhb2ALu8ftPwiGq5QN2PqaRgY9kM67Ye/cCjFzm/kLIqsNuXLKiQc1ioQJBAO7g4ZBcG/D0IxtiR4RdXYtr4wQc+cmscSKj5RPNBwn0bh9psOSg2loS/wWUmCnYSncsLGgMzPl+yPkTLwGryH0CQQDTRduiOzu6bFdOw6tI6eOxHB5h0kfcim4VT/Huh5RyP+GC7kLBmknbBO/tQXxSDVaG81Pkr+INHxJmctfKik+tAkEAtBIrl0IIAhRXnp3wYXRsPtxeLkyVc5SdWEqKNen5Y2Sx2tY2dbJXx0zIl3FTXz/fqoRPGUSFA5Kydygh6DWRlQJBAMmOfOHB9tJ8Z7LJ85AFKucdt1KFpW8eVbVZZqq0iAeTMBaULfW7tzgO9sJ3Vh6FgQYP//pNXbA883XvnDUrTKUCQQDgLO7mThmy7iqqo0be4a2ycy9fvORFYzSq1t6mTd+gr73CMCy2bTmyv/Qp4QsuPIKea0iE+HA/la5zlM8eAxOq";

    /**
     * Writes a JSON body {@code {"code": code, "message": message}} to the
     * exchange's response. The HTTP status is always 200; the outcome is carried
     * in the {@code code} field.
     *
     * @param exchange current exchange whose response is written
     * @param code     value of the {@code code} field
     * @param message  value of the {@code message} field
     * @return completion signal of the response write
     */
    public static Mono<Void> buildResponse(ServerWebExchange exchange, int code, String message) {
        ServerHttpResponse response = exchange.getResponse();
        response.setStatusCode(HttpStatus.OK);
        // set (not add) so an already present Content-Type is replaced rather than duplicated.
        response.getHeaders().set(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_UTF8_VALUE);
        JSONObject jsonObject = new JSONObject();
        jsonObject.put(CODE, code);
        jsonObject.put(MESSAGE, message);
        // Encode with the charset announced in Content-Type instead of the platform default.
        byte[] body = jsonObject.toJSONString().getBytes(java.nio.charset.StandardCharsets.UTF_8);
        DataBuffer bodyDataBuffer = response.bufferFactory().wrap(body);
        return response.writeWith(Mono.just(bodyDataBuffer));
    }
}
PS:要实现对请求参数解密,首先需要解决请求body只能被读取一次的问题。网上已有很多解决方案,大家可以自行搜索。
推荐一个:SpringCloud-Gateway获取body参数,解决只能获取一次问题,终极解决方案_kamjin1996的博客-CSDN博客
不想看上边的博客,直接使用如下代码即可:
import com.example.gateway.common.GatewayContext;
import io.netty.buffer.ByteBufAllocator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
/**
 * Global filter that reads the request payload once, stores it in a
 * {@link GatewayContext} exchange attribute, and re-exposes an equivalent body
 * to the rest of the filter chain.
 *
 * @Author: https://blog.csdn.net/zx156955/article/details/115004910
 * @Description: caches request content in the GatewayContext
 * @Date: 2023/3/30 10:11
 * @Version: 1.0
 */
@Component
@Slf4j
public class RequestCoverFilter implements GlobalFilter, Ordered {
    /**
     * Default message readers, used to decode the cached body into a String.
     */
    private static final List<HttpMessageReader<?>> messageReaders = HandlerStrategies.withDefaults().messageReaders();

    /**
     * Reads URL-encoded form data, stores it in the context, then rebuilds an
     * equivalent form body for downstream filters.
     *
     * @param exchange       current exchange
     * @param chain          remaining filter chain
     * @param gatewayContext context that receives the form data
     * @return completion signal of the downstream chain
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, GatewayFilterChain chain,
            GatewayContext gatewayContext) {
        final ServerHttpRequest request = exchange.getRequest();
        HttpHeaders headers = request.getHeaders();
        return exchange.getFormData().doOnNext(multiValueMap -> {
            gatewayContext.setFormData(multiValueMap);
            log.debug("[GatewayContext]Read FormData:{}", multiValueMap);
        }).then(Mono.defer(() -> {
            MediaType requestContentType = headers.getContentType();
            Charset declaredCharset = requestContentType == null ? null : requestContentType.getCharset();
            final Charset charset = declaredCharset == null ? StandardCharsets.UTF_8 : declaredCharset;
            String charsetName = charset.name();
            MultiValueMap<String, String> formData = gatewayContext.getFormData();
            // Nothing to repackage: continue with the untouched exchange.
            if (null == formData || formData.isEmpty()) {
                return chain.filter(exchange);
            }
            StringBuilder formDataBodyBuilder = new StringBuilder();
            try {
                // Re-serialize every key/value pair, encoding both sides.
                for (Map.Entry<String, List<String>> entry : formData.entrySet()) {
                    String encodedKey = URLEncoder.encode(entry.getKey(), charsetName);
                    for (String value : entry.getValue()) {
                        formDataBodyBuilder.append(encodedKey);
                        // A null value is written back as a bare key.
                        if (value != null) {
                            formDataBodyBuilder.append("=").append(URLEncoder.encode(value, charsetName));
                        }
                        formDataBodyBuilder.append("&");
                    }
                }
            } catch (UnsupportedEncodingException ignored) {
                // Cannot happen: charsetName comes from an existing Charset instance.
            }
            // Drop the trailing '&'.
            String formDataBodyString = "";
            if (formDataBodyBuilder.length() > 0) {
                formDataBodyString = formDataBodyBuilder.substring(0, formDataBodyBuilder.length() - 1);
            }
            byte[] bodyBytes = formDataBodyString.getBytes(charset);
            int contentLength = bodyBytes.length;
            ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(request) {
                /**
                 * Returns a copy of the headers with the length of the rebuilt body.
                 */
                @Override
                public HttpHeaders getHeaders() {
                    HttpHeaders httpHeaders = new HttpHeaders();
                    httpHeaders.putAll(super.getHeaders());
                    if (contentLength > 0) {
                        httpHeaders.setContentLength(contentLength);
                    } else {
                        httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                    }
                    return httpHeaders;
                }

                /**
                 * Replays the rebuilt form body.
                 */
                @Override
                public Flux<DataBuffer> getBody() {
                    // DataBufferUtils.read rejects a buffer size of 0.
                    if (contentLength == 0) {
                        return Flux.empty();
                    }
                    return DataBufferUtils.read(new ByteArrayResource(bodyBytes),
                            new NettyDataBufferFactory(ByteBufAllocator.DEFAULT), contentLength);
                }
            };
            ServerWebExchange mutateExchange = exchange.mutate().request(decorator).build();
            log.info("[GatewayContext]Rewrite Form Data :{}", formDataBodyString);
            return chain.filter(mutateExchange);
        }));
    }

    /**
     * Reads the whole request body into memory, stores it as a String in the
     * context, and passes on a request that replays the cached bytes.
     *
     * @param exchange       current exchange
     * @param chain          remaining filter chain
     * @param gatewayContext context that receives the body text
     * @return completion signal of the downstream chain
     */
    private Mono<Void> readBody(ServerWebExchange exchange, GatewayFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody()).flatMap(dataBuffer -> {
            /*
             * Copy the joined body into a byte array and release the buffer,
             * see PR https://github.com/spring-cloud/spring-cloud-gateway/pull/1095
             */
            byte[] bytes = new byte[dataBuffer.readableByteCount()];
            dataBuffer.read(bytes);
            DataBufferUtils.release(dataBuffer);
            // Each subscription gets its own buffer over the cached bytes.
            Flux<DataBuffer> cachedFlux = Flux.defer(() -> {
                DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
                DataBufferUtils.retain(buffer);
                return Mono.just(buffer);
            });
            ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
                @Override
                public Flux<DataBuffer> getBody() {
                    return cachedFlux;
                }
            };
            ServerWebExchange mutatedExchange = exchange.mutate().request(mutatedRequest).build();
            // Decode the cached body with the default readers, then continue the chain.
            return ServerRequest.create(mutatedExchange, messageReaders).bodyToMono(String.class)
                    .doOnNext(objectValue -> {
                        gatewayContext.setCacheBody(objectValue);
                        log.debug("[GatewayContext]Read JsonBody:{}", objectValue);
                    }).then(chain.filter(mutatedExchange));
        });
    }

    @Override
    public int getOrder() {
        return HIGHEST_PRECEDENCE;
    }

    /**
     * Creates the {@link GatewayContext}, stores it on the exchange, and caches
     * JSON or form bodies when the request declares a positive Content-Length.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        GatewayContext gatewayContext = new GatewayContext();
        String path = request.getPath().pathWithinApplication().value();
        gatewayContext.setPath(path);
        gatewayContext.getFormData().addAll(request.getQueryParams());
        gatewayContext.setIpAddress(String.valueOf(request.getRemoteAddress()));
        HttpHeaders headers = request.getHeaders();
        gatewayContext.setHeaders(headers);
        log.debug("HttpMethod:{},Url:{}", request.getMethod(), request.getURI().getRawPath());
        // Store the context before the asynchronous body read so it is already
        // on the exchange when other code looks it up.
        exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);
        MediaType contentType = headers.getContentType();
        long contentLength = headers.getContentLength();
        if (contentLength > 0) {
            // isCompatibleWith ignores parameters such as ";charset=UTF-8" and is false for null.
            if (MediaType.APPLICATION_JSON.isCompatibleWith(contentType)) {
                return readBody(exchange, chain, gatewayContext);
            }
            if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType)) {
                return readFormData(exchange, chain, gatewayContext);
            }
        }
        log.debug("[GatewayContext]ContentType:{},Gateway context is set with {}", contentType, gatewayContext);
        return chain.filter(exchange);
    }
}
网关上下文:
import lombok.Data;
import org.springframework.http.HttpHeaders;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
 * Per-request holder for request data cached by the gateway (headers, body
 * text, form data, client address, path). Accessors are generated by Lombok's
 * {@code @Data}.
 *
 * @Author: meng
 * @Description: gateway context
 * @Version: 1.0
 */
@Data
public class GatewayContext {
/**
 * Exchange attribute name under which the context instance is stored.
 */
public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";
/**
 * Cached request headers.
 */
private HttpHeaders headers;
/**
 * Cached request body as text.
 */
private String cacheBody;
/**
 * Cached form data; starts as an empty map.
 */
private MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
/**
 * Client address in string form.
 */
private String ipAddress;
/**
 * Request path.
 */
private String path;
}
三、 RSA实现对请求参数解密
1、修改yml文件,增加解密过滤器
2、RSA解密过滤器
import cn.hutool.core.util.StrUtil;
import com.example.gateway.common.CommonConstant;
import com.example.gateway.common.GatewayContext;
import com.example.gateway.utils.RSAUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;
import java.security.interfaces.RSAPrivateKey;
/**
* @Author: meng
* @Description: RSA实现对请求参数解密
* @Date: 2023/4/6 15:20
* @Version: 1.0
*/
@Slf4j
@Component
public class RSADecryptResponseGatewayFilterFactory extends AbstractGatewayFilterFactory {
@Override
public GatewayFilter apply(Object config) {
return (exchange, chain) -> {
ServerHttpRequest serverHttpRequest = exchange.getRequest();
HttpHeaders header = serverHttpRequest.getHeaders();
String decrypt = serverHttpRequest.getHeaders().getFirst("decrypt");
if (!HttpMethod.POST.matches(serverHttpRequest.getMethodValue())) {
return chain.filter(exchange);
}
byte[] decrypBytes;
GatewayContext gatewayContext = exchange.getAttribute(GatewayContext.CACHE_GATEWAY_CONTEXT);
if(StrUtil.isBlank(gatewayContext.getCacheBody())){
return CommonConstant.buildResponse(exchange, HttpStatus.BAD_REQUEST.value(), "请求参数不能为空");
}
try {
// 获取request body
String requestBody = gatewayContext.getCacheBody();
log.info("encryptMsg body :{}", requestBody);
RSAPrivateKey privateKey = RSAUtils.getPrivateKey(CommonConstant.RSA_PRIVATE_KEY);
String decryptMsg = RSAUtils.decryptByPrivateKey(requestBody, privateKey);
log.info("decryptMsg body :{}", decryptMsg);
gatewayContext.setCacheBody(decryptMsg);
decrypBytes = decryptMsg.getBytes();
} catch (Exception e) {
log.error("RSA 解密失败:{}", e);
return CommonConstant.buildResponse(exchange, HttpStatus.BAD_REQUEST.value(), "RSA解密失败");
}
// 根据解密后的参数重新构建请求
DataBufferFactory dataBufferFactory = exchange.getResponse().bufferFactory();
Flux<DataBuffer> bodyFlux = Flux.just(dataBufferFactory.wrap(decrypBytes));
ServerHttpRequest newRequest = serverHttpRequest.mutate().uri(serverHttpRequest.getURI()).build();
newRequest = new ServerHttpRequestDecorator(newRequest) {
@Override
public Flux<DataBuffer> getBody() {
return bodyFlux;
}
};
// 构建新的请求头
HttpHeaders headers = new HttpHeaders();
headers.putAll(exchange.getRequest().getHeaders());
// 由于修改了传递参数,需要重新设置CONTENT_LENGTH,长度是字节长度,不是字符串长度
int length = decrypBytes.length;
headers.remove(HttpHeaders.CONTENT_LENGTH);
headers.setContentLength(length);
newRequest = new ServerHttpRequestDecorator(newRequest) {
@Override
public HttpHeaders getHeaders() {
return headers;
}
};
// 把解密后的数据重置到exchange自定义属性中,在之后的日志GlobalLogFilter从此处获取请求参数打印日志
exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);
return chain.filter(exchange.mutate().request(newRequest).build());
};
}
}
3、测试结果:
四、 RSA实现对返回结果加密
最简单的方法是使用gateway全局过滤器实现,但是我们有可能不需要对所有接口的结果进行加密,我们也可以使用局部过滤器实现。
实践过程中遇到一个问题,ServerHttpResponseDecorator不生效,是因为过滤器顺序问题。
全局过滤器:
自定义的GlobalFilter的order必须小于-1,否则标准的NettyWriteResponseFilter将在该过滤器有机会被调用之前发送响应。
相关讨论:https://github.com/spring-cloud/spring-cloud-gateway/issues/47
局部过滤器:
官方提供了一种方式,详细介绍:Spring Cloud Gateway
如果我们想使用yml+nacos实现动态路由的形式,我们可以借鉴原生类ModifyResponseBodyGatewayFilterFactory实现。
1、修改yml文件,增加加密过滤器
2、RSA加密过滤器
import cn.hutool.core.util.StrUtil;
import com.example.gateway.common.CommonConstant;
import com.example.gateway.utils.RSAUtils;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.nio.charset.Charset;
import java.security.interfaces.RSAPrivateKey;
/**
 * Route filter that RSA-encrypts the response body of POST requests, modelled
 * on Spring Cloud Gateway's ModifyResponseBodyGatewayFilterFactory.
 * https://docs.spring.io/spring-cloud-gateway/docs/current/reference/html/#the-modifyresponsebody-gatewayfilter-factory
 *
 * <p>NOTE(review): the body is encrypted with the private key, so any holder of
 * the public key can decrypt it.
 *
 * @Author: meng
 * @Description: RSA encryption of the response
 * @Date: 2023/4/3 11:37
 * @Version: 1.0
 */
@Slf4j
@Component
public class RSAEncryptResponseGatewayFilterFactory extends AbstractGatewayFilterFactory<Object> {
    /**
     * @param config unused
     * @return a new {@link RewriteResponseGatewayFilter}
     */
    @Override
    public GatewayFilter apply(Object config) {
        return new RewriteResponseGatewayFilter();
    }

    /**
     * Filter that decorates the response so that the body is encrypted when it is written.
     */
    public class RewriteResponseGatewayFilter implements GatewayFilter, Ordered {
        @Override
        public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
            if (!HttpMethod.POST.matches(exchange.getRequest().getMethodValue())) {
                return chain.filter(exchange);
            }
            ServerHttpResponse originalResponse = exchange.getResponse();
            DataBufferFactory bufferFactory = originalResponse.bufferFactory();
            ServerHttpResponseDecorator decoratedResponse = new ServerHttpResponseDecorator(originalResponse) {
                @Override
                public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
                    // The body may arrive in several buffers. Join them first so the whole
                    // payload is encrypted and Base64-encoded as one unit.
                    Flux<DataBuffer> chunks = Flux.from(body);
                    Mono<DataBuffer> encryptedBody = DataBufferUtils.join(chunks).map(dataBuffer -> {
                        byte[] content = new byte[dataBuffer.readableByteCount()];
                        dataBuffer.read(content);
                        DataBufferUtils.release(dataBuffer);
                        String responseStr = new String(content, Charset.forName(CommonConstant.UTF8));
                        log.info("RSA response:{}", responseStr);
                        byte[] newContent;
                        try {
                            RSAPrivateKey privateKey = RSAUtils.getPrivateKey(CommonConstant.RSA_PRIVATE_KEY);
                            String encryptMsg = RSAUtils.encryptByPrivateKey(responseStr, privateKey);
                            log.info("RSA encryptMsg:{}", encryptMsg);
                            newContent = encryptMsg.getBytes(CommonConstant.UTF8);
                        } catch (Exception e) {
                            // The exception is passed as the throwable argument so the stack trace is logged.
                            log.error("RSA fail", e);
                            throw new IllegalStateException("RSA response encryption failed", e);
                        }
                        // The encrypted body has a different size than the upstream one, so the
                        // upstream Content-Length must not be forwarded unchanged.
                        org.springframework.http.HttpHeaders headers = originalResponse.getHeaders();
                        if (!headers.containsKey(org.springframework.http.HttpHeaders.TRANSFER_ENCODING)
                                || headers.containsKey(org.springframework.http.HttpHeaders.CONTENT_LENGTH)) {
                            headers.setContentLength(newContent.length);
                        }
                        return bufferFactory.wrap(newContent);
                    });
                    return super.writeWith(encryptedBody);
                }
            };
            return chain.filter(exchange.mutate().response(decoratedResponse).build());
        }

        /**
         * Runs just before NettyWriteResponseFilter so the decorator is in place
         * when the response is written.
         */
        @Override
        public int getOrder() {
            return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 1;
        }
    }
}
3、测试结果