背景
MyBatis 四大核心组件的作用(下面开发的插件拦截器会拦截其中的 StatementHandler 和 ResultSetHandler)
四大组件Executor、StatementHandler、ParameterHandler、ResultSetHandler
需求
1、根据脱敏规则查询数据,在展示时对数据进行脱敏
2、根据脱敏规则查询数据,并将脱敏后的数据批量更新回数据库,实现脱敏存储
数据
/*
Navicat Premium Data Transfer
Source Server : localhost
Source Server Type : MySQL
Source Server Version : 80026 (8.0.26)
Source Host : localhost:3306
Source Schema : mp
Target Server Type : MySQL
Target Server Version : 80026 (8.0.26)
File Encoding : 65001
Date: 18/08/2024 13:44:24
*/
SET NAMES utf8mb4;
SET FOREIGN_KEY_CHECKS = 0;
-- ----------------------------
-- Table structure for user
-- ----------------------------
DROP TABLE IF EXISTS `user`;
CREATE TABLE `user` (
    `id` bigint NOT NULL AUTO_INCREMENT COMMENT '用户id',
    `username` varchar(50) CHARACTER SET utf8 COLLATE utf8_general_ci NOT NULL COMMENT '用户名',
    `password` varchar(128) CHARACTER SET utf8 COLLATE utf8_general_ci NOT NULL COMMENT '密码',
    `phone` varchar(20) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL COMMENT '注册手机号',
    `info` json NOT NULL COMMENT '详细信息',
    `status` int NULL DEFAULT 1 COMMENT '使用状态(1正常 2冻结)',
    `balance` int NULL DEFAULT NULL COMMENT '账户余额',
    `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    `update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    PRIMARY KEY (`id`) USING BTREE,
    UNIQUE INDEX `username`(`username` ASC) USING BTREE
) ENGINE = InnoDB AUTO_INCREMENT = 5 CHARACTER SET = utf8 COLLATE = utf8_general_ci COMMENT = '用户表' ROW_FORMAT = COMPACT;
-- ----------------------------
-- Records of user
-- Explicit column lists keep these inserts correct if columns are later added or reordered.
-- ----------------------------
INSERT INTO `user` (`id`, `username`, `password`, `phone`, `info`, `status`, `balance`, `create_time`, `update_time`)
VALUES (1, 'Jack', 'D5F3C4C80A651CFE876DA231061FA871', '13900112224', '{\"age\": 20, \"intro\": \"佛系青年\", \"gender\": \"male\"}', 1, 1600, '2023-05-19 20:50:21', '2024-08-18 13:43:05');
INSERT INTO `user` (`id`, `username`, `password`, `phone`, `info`, `status`, `balance`, `create_time`, `update_time`)
VALUES (2, 'Rose', 'D5F3C4C80A651CFE876DA231061FA871', '13900112223', '{\"age\": 19, \"intro\": \"青涩少女\", \"gender\": \"female\"}', 1, 600, '2023-05-19 21:00:23', '2024-08-18 13:43:09');
INSERT INTO `user` (`id`, `username`, `password`, `phone`, `info`, `status`, `balance`, `create_time`, `update_time`)
VALUES (3, 'Hope', 'D5F3C4C80A651CFE876DA231061FA871', '13900112222', '{\"age\": 25, \"intro\": \"上进青年\", \"gender\": \"male\"}', 1, 100000, '2023-06-19 22:37:44', '2024-08-18 13:43:12');
INSERT INTO `user` (`id`, `username`, `password`, `phone`, `info`, `status`, `balance`, `create_time`, `update_time`)
VALUES (4, 'Toomas', 'D5F3C4C80A651CFE876DA231061FA871', '17701265258', '{\"age\": 29, \"intro\": \"伏地魔\", \"gender\": \"male\"}', 1, 800, '2023-06-19 23:44:45', '2024-08-18 13:43:27');
SET FOREIGN_KEY_CHECKS = 1;
具体实现
1、定义脱敏规则
mybatis:
  interceptors:
    # Whether interception is enabled
    enabled: true

# Bound to DesensitizationConfig (@ConfigurationProperties(prefix = "desensitization"))
desensitization:
  # Whether masking is applied; the string "false" disables it
  enabled: true
  rule:
    # Schema the rules apply to (matched case-insensitively)
    - schema: MP
      # true: apply allTableRule to every table in the schema; otherwise use schemaRule
      allSchema: true
      # Per-table rules for this schema
      schemaRule:
        # Table to mask
        - tableName: USER
          # Primary-key column of the table (used by the batch write-back)
          keyColumn: id
          # Column rules for this table
          tableRule:
            - column: username
              columnRule:
                # Hide part of the value with '*'
                ruleType: hide
                startIndex: 1
                endIndex: 2
            - column: password
              columnRule:
                # Encrypt the password with AES
                ruleType: aes
      # Column rules applied to every table when allSchema is true
      allTableRule:
        - column: username
          columnRule:
            ruleType: hide
            startIndex: 1
            endIndex: 2
        - column: password
          columnRule:
            # Encrypt the password with AES
            ruleType: aes
2、配置类
/**
 * Root of the masking configuration, bound from the {@code desensitization} prefix.
 *
 * <p>{@code enabled} is kept as a String: callers compare it to the literal "false"
 * to decide whether masking is switched off.
 *
 * @author code
 * @version 1.0
 */
@Component
@ConfigurationProperties(prefix = "desensitization")
@Data
public class DesensitizationConfig {

    /** Global switch; "false" (the default) means no masking. */
    private String enabled = "false";

    /** One entry per schema; stays null when nothing is configured. */
    private List<DesensitizationBaseRule> rule;
}
/**
 * Masking rules for a single schema.
 *
 * <p>When {@code allSchema} is exactly "true" the {@code allTableRule} list is meant to apply to
 * every table of the schema; otherwise the per-table {@code schemaRule} entries are used.
 *
 * @author code
 * @version 1.0
 */
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Data
public class DesensitizationBaseRule implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Schema name this rule set belongs to. */
    private String schema = "schema";

    /** "true" selects allTableRule; any other value selects schemaRule. */
    private String allSchema = "false";

    /** Per-table rules, used when allSchema is not "true". */
    private List<DesensitizationSchemaRule> schemaRule = new ArrayList<>();

    /** Column rules shared by every table, used when allSchema is "true". */
    private List<DesensitizationTableRule> allTableRule = new ArrayList<>();
}
/**
 * Masking rules for one table inside a schema.
 *
 * @author code
 * @version 1.0
 */
@Data
@EqualsAndHashCode(callSuper = false)
@NoArgsConstructor
public class DesensitizationSchemaRule implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Table name this entry applies to. */
    private String tableName = "table";

    /** Primary-key column of the table; defaults to "id". */
    private String keyColumn = "id";

    /**
     * Column rules for this table. Defaults to an empty list (like the lists in
     * DesensitizationBaseRule) so a table entry configured without {@code tableRule}
     * yields "no rules" instead of a null that callers would dereference.
     * Fully qualified because this file's import list is not shown here.
     */
    private List<DesensitizationTableRule> tableRule = new java.util.ArrayList<>();
}
/**
 * Masking rule for a single column.
 *
 * @author code
 * @version 1.0
 */
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Data
public class DesensitizationTableRule implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Column name; used as the key when reading and writing row maps. */
    private String column = "name";

    /** How the column is masked (rule type plus its parameters). */
    private DesensitizationMaskRule columnRule;
}
/**
 * Parameters of one masking operation.
 *
 * @author code
 * @version 1.0
 */
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Data
public class DesensitizationMaskRule implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Strategy key passed to the policy lookup, e.g. "hide" or "aes"; defaults to "hide". */
    private String ruleType = "hide";

    /** NOTE(review): presumably the first masked character index for "hide" — confirm in the hide service. */
    private int startIndex = 1;

    /** NOTE(review): presumably the end of the masked range for "hide" — confirm in the hide service. */
    private int endIndex = 2;
}
3、拦截器
/**
 * MyBatis plugin on {@code StatementHandler.prepare}. It stores the JDBC catalog and the table
 * name parsed from the SQL in {@link ContextUtil}, so later code on the same thread can look up
 * the masking rules for that table.
 *
 * <p>{@link #modifySql} is currently a pass-through hook. The bound SQL is only rewritten when
 * the hook returns something different.
 *
 * @author code
 * @version 1.0
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Slf4j
public class SqlQueryInterceptor implements Interceptor {

    /**
     * Captures the identifier following FROM/UPDATE in upper-cased SQL. The leading word boundary
     * stops matches inside identifiers such as DATE_FROM. Backticks are accepted so quoted names
     * are captured too; they are stripped afterwards.
     */
    private static final Pattern TABLE_NAME_PATTERN = Pattern.compile("\\b(FROM|UPDATE)\\s+([`\\w.]+)");

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // prepare(Connection, Integer): the first argument is the connection the statement is prepared on.
        Connection connection = (Connection) invocation.getArgs()[0];
        String databaseName = connection.getCatalog();
        log.info("----------当前数据库名:{}", databaseName);

        JSONObject paramObject = toJsonObject(statementHandler.getParameterHandler().getParameterObject());

        BoundSql boundSql = statementHandler.getBoundSql();
        String originalSql = boundSql.getSql();
        String tableName = extractTableName(originalSql);
        // Only overwrite the thread context when both parts are known.
        if (StringUtils.hasLength(databaseName) && StringUtils.hasLength(tableName)) {
            ContextUtil.setSchemaName(databaseName);
            ContextUtil.setTableName(tableName);
        }

        String modifySql = modifySql(originalSql, paramObject);
        if (!originalSql.equals(modifySql)) {
            // BoundSql has no setter for its SQL text, so the private field is written reflectively.
            Field field = BoundSql.class.getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql, modifySql);
            log.info("----------修改后的sql:{}", boundSql.getSql());
        }
        return invocation.proceed();
    }

    /**
     * Returns the name after the last FROM/UPDATE in the upper-cased SQL, with backticks removed,
     * or "" if there is none.
     */
    private String extractTableName(String sql) {
        String upperSql = sql.toUpperCase();
        String tableName = "";
        Matcher matcher = TABLE_NAME_PATTERN.matcher(upperSql);
        while (matcher.find()) {
            tableName = matcher.group(2).replace("`", "");
            log.info("----------当前sql语句:{},表名:{}", upperSql, tableName);
        }
        return tableName;
    }

    /**
     * Converts the MyBatis parameter object to a JSONObject when it serialises to a JSON object
     * (a Map or bean). Scalar parameters, such as the id passed to selectById, serialise to a bare
     * JSON value that JSONObject.parseObject rejects, so they yield null here. The value is only
     * used for logging, so a serialisation failure must not fail the query.
     */
    private JSONObject toJsonObject(Object parameterObject) {
        if (parameterObject == null) {
            return null;
        }
        try {
            Object parsed = JSON.parse(JSON.toJSONString(parameterObject));
            return parsed instanceof JSONObject ? (JSONObject) parsed : null;
        } catch (RuntimeException e) {
            log.warn("parameter object could not be converted to JSON: {}", e.toString());
            return null;
        }
    }

    /**
     * Hook for rewriting the SQL; currently logs the optional tablename/fieldname/fieldcode
     * parameters and returns the SQL unchanged.
     *
     * @param originalSql SQL as bound by MyBatis
     * @param paramObject parameter object as JSON, or null
     * @return the SQL to execute
     */
    private String modifySql(String originalSql, JSONObject paramObject) {
        String tableName = "";
        String fieldName = "";
        String fieldCode = "";
        if (ObjectUtil.isNotEmpty(paramObject)) {
            tableName = paramObject.getString("tablename");
            fieldName = paramObject.getString("fieldname");
            fieldCode = paramObject.getString("fieldcode");
        }
        if (StringUtils.hasLength(tableName)) {
            log.info("tableName:{}", tableName);
        }
        if (StringUtils.hasLength(fieldName)) {
            log.info("fieldName:{}", fieldName);
        }
        if (StringUtils.hasLength(fieldCode)) {
            log.info("fieldCode:{}", fieldCode);
        }
        log.info("modify sql:{}", originalSql);
        return originalSql;
    }
}
/**
 * MyBatis plugin on {@code ResultSetHandler.handleResultSets}. It masks the configured columns
 * of each result row in place, using the schema/table held in {@link ContextUtil}.
 *
 * <p>Only rows that are {@link Map}s (e.g. resultType map / selectMaps) can be masked. Other row
 * types (entities) are passed through unchanged. Missing configuration, missing context or
 * unknown rule types leave the result untouched instead of failing the query.
 *
 * @author code
 * @version 1.0
 */
@Intercepts({
        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})
})
@Slf4j
public class ResultSetInterceptor implements Interceptor {
    private static final String TABLE_NAME = ".";
    private static final String PATTERN_TABLE_NAME = "\\.";
    @Resource
    private DesensitizationConfig desensitizationConfig;
    @Resource
    private BaseDesensitizationPolicy baseDesensitizationPolicy;

    @Override
    public Object plugin(Object target) {
        if (target instanceof ResultSetHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object resultObject = invocation.proceed();
        if (Objects.isNull(resultObject)) {
            return null;
        }
        // Check the global switch first, so a disabled feature never depends on the thread context.
        if (Boolean.FALSE.toString().equals(desensitizationConfig.getEnabled())) {
            return resultObject;
        }
        String schemaName = ContextUtil.getSchemaName();
        String tableName = ContextUtil.getTableName();
        if (tableName == null) {
            // No table recorded for this thread: nothing to look up.
            return resultObject;
        }
        // A qualified name "SCHEMA.TABLE" overrides the catalog-derived schema.
        if (tableName.contains(TABLE_NAME)) {
            String[] schemaAndTable = tableName.split(PATTERN_TABLE_NAME);
            if (schemaAndTable.length >= 2) {
                schemaName = schemaAndTable[0];
                tableName = schemaAndTable[1];
            }
        }
        log.info("----------拦截的数据库:{},表:{}", schemaName, tableName);
        log.info("----------当前的脱敏规则:{}", desensitizationConfig.toString());

        List<DesensitizationTableRule> columnRule = resolveColumnRule(schemaName, tableName);
        if (columnRule.isEmpty()) {
            return resultObject;
        }
        // Paged result
        if (resultObject instanceof IPage) {
            IPage page = (IPage) resultObject;
            if (CollUtil.isNotEmpty(page.getRecords())) {
                for (Object res : page.getRecords()) {
                    doDealData(res, columnRule);
                }
            }
            return resultObject;
        }
        // handleResultSets is declared to return a List, so this is the usual path.
        if (resultObject instanceof List) {
            for (Object res : (List) resultObject) {
                doDealData(res, columnRule);
            }
            return resultObject;
        }
        // Single object
        doDealData(resultObject, columnRule);
        return resultObject;
    }

    /**
     * Finds the column rules that apply to schema/table.
     *
     * @return the configured rules, or an empty list when the schema, table or rule set is unknown
     */
    private List<DesensitizationTableRule> resolveColumnRule(String schema, String table) {
        List<DesensitizationTableRule> none = new ArrayList<>();
        List<DesensitizationBaseRule> rule = desensitizationConfig.getRule();
        if (schema == null || rule == null) {
            return none;
        }
        DesensitizationBaseRule baseRule = rule.stream()
                .filter(d -> schema.equalsIgnoreCase(d.getSchema()))
                .findFirst()
                .orElse(null);
        if (baseRule == null) {
            return none;
        }
        List<DesensitizationTableRule> columnRule;
        if (Boolean.TRUE.toString().equals(baseRule.getAllSchema())) {
            columnRule = baseRule.getAllTableRule();
        } else {
            List<DesensitizationSchemaRule> schemaRule = baseRule.getSchemaRule();
            DesensitizationSchemaRule tableEntry = schemaRule == null ? null : schemaRule.stream()
                    .filter(s -> table.equalsIgnoreCase(s.getTableName()))
                    .findFirst()
                    .orElse(null);
            columnRule = tableEntry == null ? null : tableEntry.getTableRule();
        }
        return columnRule == null ? none : columnRule;
    }

    /**
     * Masks the configured columns of one row in place.
     *
     * @param result      a result row; only {@link Map} rows are modified
     * @param columnRule  rules to apply (non-null)
     * @return the same row object
     */
    private Object doDealData(Object result, List<DesensitizationTableRule> columnRule) throws Exception {
        if (!(result instanceof Map)) {
            // Entity rows cannot be masked by column key; leave them untouched instead of failing the cast.
            log.info("跳过非Map结果:{}", result.getClass().getName());
            return result;
        }
        @SuppressWarnings("unchecked")
        Map<Object, Object> map = (Map<Object, Object>) result;
        log.info("当前的json数据:{}, 脱敏规则:{}", JSON.toJSONString(map), columnRule.toString());
        for (DesensitizationTableRule c : columnRule) {
            Object value = map.get(c.getColumn());
            // Nothing to mask for an absent/NULL column or a rule without parameters.
            if (value == null || c.getColumnRule() == null) {
                continue;
            }
            BaseDesensitizationService service = baseDesensitizationPolicy.getService(c.getColumnRule().getRuleType());
            if (service == null) {
                log.warn("未找到脱敏规则实现:{}", c.getColumnRule().getRuleType());
                continue;
            }
            String hideValue = service.desensitizationStr(value.toString(), c.getColumnRule());
            log.info("当前的字段:{}, 脱敏规则:{},脱敏后的数据:{}", c.getColumn(), c.getColumnRule().toString(), hideValue);
            map.put(c.getColumn(), hideValue);
        }
        return map;
    }
}
4、注册拦截器
/**
 * Exposes SqlQueryInterceptor as a Spring bean so it is registered as a MyBatis plugin.
 */
@Order(1)
@Bean
public SqlQueryInterceptor sqlQueryInterceptor() {
    SqlQueryInterceptor interceptor = new SqlQueryInterceptor();
    return interceptor;
}
/**
 * Exposes ResultSetInterceptor as a Spring bean; being a bean also lets its @Resource fields be injected.
 */
@Order(2)
@Bean
public ResultSetInterceptor resultSetInterceptor() {
    ResultSetInterceptor interceptor = new ResultSetInterceptor();
    return interceptor;
}
5、工具类
/**
* @Description
* @Author code
* @Create 2024-08-18 11:02
*/
@Slf4j
public class AESUtil {
/**
* 定义 aes 加密的key
* 密钥 必须是16位, 自定义,
* 如果不是16位, 则会出现InvalidKeyException: Illegal key size
* 解决方案有两种:
* 需要安装Java Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files(可以在Oracle下载).
* .设置设置key的长度为16个字母和数字的字符窜(128 Bit/8=16字符)就不报错了。
*/
//秘钥设置16位
private static final String SECRET_KEY = "abcdeg123456!#$%";
public static String getSecretKey(){
return SECRET_KEY;
}
//加密方式
private static final String AES_TYPE = "AES";
//秘钥最大长度16位
private static final int BYTE_LENGTH = 16;
/**
* 加密AES
*
* @param value 字符串
* @param key 秘钥
* @return String
*/
private static String encryptAES(String key, String value) {
try {
byte[] keyBytes = Arrays.copyOf(key.getBytes(StandardCharsets.US_ASCII), BYTE_LENGTH);
SecretKey keyStr = new SecretKeySpec(keyBytes, AES_TYPE);
Cipher cipher = Cipher.getInstance(AES_TYPE);
cipher.init(Cipher.ENCRYPT_MODE, keyStr);
byte[] cleartext = value.getBytes(StandardCharsets.UTF_8);
byte[] ciphertextBytes = cipher.doFinal(cleartext);
return new String(Hex.encodeHex(ciphertextBytes)).toUpperCase();
} catch (Exception ex) {
ex.printStackTrace();
}
return null;
}
/**
* 解密AES
*
* @param encrypted 字符串
* @param key 秘钥
* @return String
*/
private static String decryptAES(String key, String encrypted) {
try {
byte[] keyBytes = Arrays.copyOf(key.getBytes(StandardCharsets.US_ASCII), BYTE_LENGTH);
SecretKey keyStr = new SecretKeySpec(keyBytes, AES_TYPE);
Cipher cipher = Cipher.getInstance(AES_TYPE);
cipher.init(Cipher.DECRYPT_MODE, keyStr);
byte[] content = Hex.decodeHex(encrypted.toCharArray());
byte[] ciphertextBytes = cipher.doFinal(content);
return new String(ciphertextBytes);
} catch (Exception ex) {
ex.printStackTrace();
}
return null;
}
/**
* 加密
*
* @param value 字符串
* @param pwd 秘钥
* @return String
*/
public static String encryptValue(String value, String pwd) {
return StrUtil.isEmpty(value) ? null : encryptAES(pwd, value.trim());
}
/**
* 解密
*
* @param value 字符串
* @param pwd 秘钥
* @return String
*/
public static String decryptValue(String value, String pwd) {
if (StrUtil.isNotEmpty(value)) {
String value2 = decryptAES(pwd, value.toLowerCase());
value = StrUtil.isEmpty(value2) ? value : value2;
}
return value;
}
/**
* main方法
*
* #mysql 进行AES加密
* SELECT HEX(AES_ENCRYPT("原始字符串","密钥"));
*
* #mysql 进行AES解密
* SELECT AES_DECRYPT(UNHEX("加密后字符串"), "密钥");
*
* @param args 参数
*/
public static void main(String[] args) {
//加密测试
System.out.println(encryptValue("张三", ""));
//解密测试
System.out.println(decryptValue(encryptValue("张三", ""), ""));
//加密测试
System.out.println(encryptValue("李四", SECRET_KEY));
//解密测试
System.out.println(decryptValue(encryptValue("李四", SECRET_KEY), SECRET_KEY));
}
}
/**
 * Per-thread holder for a schema name and a table name.
 * Values live until a caller removes them or overwrites them on the same thread.
 *
 * @author code
 * @version 1.0
 */
public class ContextUtil {

    private static final ThreadLocal<String> SCHEMA_HOLDER = new ThreadLocal<>();

    private static final ThreadLocal<String> TABLE_HOLDER = new ThreadLocal<>();

    /** Stores the table name for the current thread. */
    public static void setTableName(String tableName) {
        TABLE_HOLDER.set(tableName);
    }

    /** @return the current thread's table name, or null if none is set */
    public static String getTableName() {
        return TABLE_HOLDER.get();
    }

    /** Clears the current thread's table name. */
    public static void removerTableName() {
        TABLE_HOLDER.remove();
    }

    /** Stores the schema name for the current thread. */
    public static void setSchemaName(String schemaName) {
        SCHEMA_HOLDER.set(schemaName);
    }

    /** @return the current thread's schema name, or null if none is set */
    public static String getSchemaName() {
        return SCHEMA_HOLDER.get();
    }

    /** Clears the current thread's schema name. */
    public static void removerSchemaName() {
        SCHEMA_HOLDER.remove();
    }
}
/**
* @author code
* @version 1.0
* @Date 2024/6/17 16:26
* @Description ${DESCRIPTION}
*/
@Component
@Slf4j
public class JdbcSqlUtil {
private static final String TABLE_NAME = ".";
private static final String PATTERN_TABLE_NAME = "\\.";
@Resource
private DataSource dataSource;
@Resource
private DesensitizationConfig desensitizationConfig;
public Map<String, Object> getDataForMap(String sql) {
long startTime = System.currentTimeMillis();
log.info("-------------------------executeQuerySqlMap开始时间:{}", startTime);
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
log.info("executeQuerySqlMap:{}", sql);
Map<String, Object> maps = jdbcTemplate.queryForMap(sql);
long endTime = System.currentTimeMillis();
log.info("-------------------------executeQuerySqlMap结束时间:{},消耗时间:{}", endTime, endTime - startTime);
return maps;
}
/**
* 通用sql查询,支持一般函数、join、分页
* 执行sql语句
*
* @param commonSql
* @return
*/
public List<Map<String, Object>> getCommonDataList(String commonSql) {
long startTime = System.currentTimeMillis();
log.info("-------------------------executeQuerySql开始时间:{}", startTime);
String sql = sqlFormat(commonSql);
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
log.info("common-------query-------executeQuerySql:{}", sql);
List<Map<String, Object>> maps = jdbcTemplate.queryForList(sql);
long endTime = System.currentTimeMillis();
log.info("common-------query-------executeQuerySql结束时间:{},消耗时间:{}", endTime, endTime - startTime);
return maps;
}
public Object butchUpdateData(List<Map<String, Object>> data) {
if (CollectionUtils.isEmpty(data)) {
return "无更新数据";
}
//判断规则是否开启
//脱敏处理
if (Boolean.FALSE.toString().equals(desensitizationConfig.getEnabled())) {
//不进行脱敏
return "不进行脱敏";
}
String schemaName = ContextUtil.getSchemaName();
String tableName = ContextUtil.getTableName();
if (tableName.contains(TABLE_NAME)) {
String[] schemaAndTable = tableName.split(PATTERN_TABLE_NAME);
schemaName = schemaAndTable[0];
tableName = schemaAndTable[1];
}
//处理模式名和表名
ContextUtil.removerSchemaName();
ContextUtil.removerTableName();
log.info("----------拦截的数据库:{},表:{}", schemaName, tableName);
log.info("----------处理后的数据库:{},表:{}", ContextUtil.getSchemaName(), ContextUtil.getTableName());
//查找主键字段
final String schema = schemaName;
final String table = tableName;
List<DesensitizationBaseRule> rule = desensitizationConfig.getRule();
//获取当前schema的规则
DesensitizationBaseRule desensitizationBaseRule = rule.stream().filter(d -> schema.equalsIgnoreCase(d.getSchema())).findFirst().orElse(null);
//定义脱敏规则
List<DesensitizationTableRule> columnRule = new ArrayList<>();
//是否全表脱敏规则
if (Boolean.TRUE.toString().equals(desensitizationBaseRule.getAllSchema())) {
return "全模式脱敏,无法批量更新";
}
List<DesensitizationSchemaRule> schemaRule = desensitizationBaseRule.getSchemaRule();
DesensitizationSchemaRule desensitizationSchemaRule = schemaRule.stream().filter(s -> table.equalsIgnoreCase(s.getTableName())).findFirst().orElse(null);
String keyColumn = desensitizationSchemaRule.getKeyColumn();
List<DesensitizationTableRule> tableRule = desensitizationSchemaRule.getTableRule();
//构造sql
StringBuilder columns = new StringBuilder();
tableRule.forEach(d -> {
columns.append(d.getColumn()).append("=?,");
});
//去掉最后一个逗号
String column = columns.toString().substring(0, columns.length() - 1);
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
DataSource dataSource = jdbcTemplate.getDataSource();
String updateSql = " update " + schema + "." + table + " set " + column + " where " + keyColumn + " = ?";
BatchSqlUpdate bsu = new BatchSqlUpdate(dataSource, updateSql);
bsu.setBatchSize(1000);
//参数类型匹配
int[] types = new int[tableRule.size() + 1];
for (int i = 0; i <tableRule.size() + 1; i++) {
types[i] = Types.VARCHAR;
}
bsu.setTypes(types);
data.forEach(m -> {
Object[] objects = new Object[tableRule.size() + 1];
for (int i = 0; i < tableRule.size(); i++) {
String key = tableRule.get(i).getColumn();
objects[i] = m.get(key);
}
int keyIndex = tableRule.size();
objects[keyIndex] = m.get(keyColumn);
bsu.update(objects);
});
bsu.flush();
return "OK";
}
//格式化sql
private String sqlFormat(String sql) {
String type = null;
try {
Connection connection = dataSource.getConnection();
type = connection.getMetaData().getDatabaseProductName().toLowerCase();
connection.close();
} catch (SQLException e) {
log.error("db type error");
}
String res = SQLUtils.format(sql, type);
log.info("format sql : {}", res);
return res;
}
}
6、结果展示
原始数据
启用脱敏规则后查询
脱敏后更新数据库数据
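需要调用 JdbcSqlUtil 的批量更新方法 jdbcSqlUtil.butchUpdateData(),传入脱敏后的数据
需要在脱敏规则中为该模式指定表名(tableName)、主键字段(keyColumn)和列规则(tableRule),并且不能开启全模式脱敏(allSchema 为 true 时无法批量更新)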