Mybatis配置拦截器实现参数加解密+注解

 利用Mybatis拦截器+反射机制,设计加解密注解,可以对特定字段入库出库时,实现自动加解密。
①拦截器原理详见Mybatis插件系列一:拦截器的基础知识

 

本文是使用Executor 插入跟更新操作是在一起的。

package com.config;

import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.annotation.DataSecurity;
import com.util.encrypt.service.ApiEncryptService;
import com.util.ruoyi.utils.bean.BeanUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.Base64Utils;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * 自定义 Mybatis 数据安全处理拦截器类。
 */
@Slf4j
@Component
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}), @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
public class CustomizeMybatisDataSecurityHandlerInterceptor implements Interceptor {
    /**
     * 加密服务管理对象。
     */
    @Autowired
    private ApiEncryptService mApiEncryptService;
    /**
     * 查询条件 key 名称。
     */
    private static final String CRITERIA_KEY_NAME = "oredCriteria";
    /**
     * 获取条件集合 key 名称。
     */
    private static final String CRITERIA_GET_KEY_NAME = "getCriteria";
    /**
     * 获取条件集合下的条件 key 名称。
     */
    private static final String CRITERIA_GET_CONDITION_KEY_NAME = "getCondition";
    /**
     * 条件集合下的条件值 key 名称。
     */
    private static final String CRITERIA_VALUE_KEY_NAME = "value";
    /**
     * 获取条件集合下的条件值 key 名称。
     */
    private static final String CRITERIA_GET_VALUE_KEY_NAME = "getValue";
    /**
     * 方法对象缓存(便于下次快速获取对象)。<br/>
     * key:Mapper 中的方法(包括包路径)。<br/>
     * value:对应的方法对象。
     */
    private static final Map<String, Method> methodCache = new HashMap<>();
    /**
     * 需要排除的数据类型集合。
     */
    private static final List<Class> excludedDataTypeList = Arrays.asList(String.class, Integer.class, Boolean.class, Double.class, Float.class, Long.class, Short.class, Byte.class, Character.class);

    /***
     * 实现 intercept 方法,该方法将传递 Invocation 对象作为参数。可以使用该对象调用原始方法,并对其返回值进行处理。
     * @param invocation
     * @return
     * @throws Throwable
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] _args = invocation.getArgs();
        if (_args != null && _args.length > 1) {
            MappedStatement _ms = (MappedStatement) _args[0];
            Object _parameter = _args[1]; // 可能是实体对象、Map 对象
            if (_ms != null && _parameter != null) {
                String _orgMethodId = _ms.getId();
                // 针对 Example 的查询方式,会自动生成 xxxxByExample_COUNT 方法,故这里进行处理
                String _keyString = "ByExample_COUNT";
                if (_orgMethodId.endsWith(_keyString)) {
                    _keyString = "_COUNT";
                    _orgMethodId = _orgMethodId.substring(0, _orgMethodId.length() - _keyString.length());
                }
                Method _orgMethod = methodCache.get(_orgMethodId);
                DataSecurity _orgMethodDataSecurity = null;
                if (_orgMethod == null) {
                    String _mapperClassName = _orgMethodId.substring(0, _orgMethodId.lastIndexOf("."));
                    Class<?> _mapperClass = Class.forName(_mapperClassName);
                    if (_mapperClass != null) {
                        String _methodName = _orgMethodId.substring(_orgMethodId.lastIndexOf(".") + 1);
                        _orgMethod = Arrays.stream(_mapperClass.getMethods()).filter(method -> method.getName().equalsIgnoreCase(_methodName)).findFirst().orElse(null);
                        if (_orgMethod != null) {
                            methodCache.put(_orgMethodId, _orgMethod);
                        }
                    }
                }
                if (_orgMethod != null && _orgMethod.isAnnotationPresent(DataSecurity.class)) {
                    _orgMethodDataSecurity = _orgMethod.getAnnotation(DataSecurity.class);
                }
                switch (_ms.getSqlCommandType()) {
                    case INSERT: // 处理插入逻辑
                    case UPDATE: // 处理更新逻辑
                        // 是否为批量插入,真表示是,反之不是。
//                        boolean _isBatchInsert = _parameter instanceof Map;
                        String methodName=_orgMethod.getName();
                        boolean _isBatchInsert =methodName.toUpperCase().contains("BATCH");
                        // log.info("是否为批量插入:" + _isBatchInsert);
                        if (_isBatchInsert) {
                            Map<String, Object> _tempMap = (Map<String, Object>) _parameter;
                            String arrayKey="list";
                            try{
                                if(_tempMap.get("collection")!=null){
                                    arrayKey="collection" ;
                                }
                            }catch (Exception e){
                                log.error(e.getMessage());
                            }
                            List<?> _list = (List<?>) Optional.ofNullable(_tempMap.get(arrayKey)).orElse(new ArrayList<Map<String, Object>>());
                            if (CollectionUtil.isNotEmpty(_list)) {
                                // 克隆新对象,防止影响原始数据
                                _list = BeanUtils.copyToList(_list, _list.get(0).getClass());
                                for (Object _entity : _list) {
                                    this.encryptField(_entity, _orgMethodDataSecurity);
                                }
                            }
                            for (Map.Entry<String, Object> entry : _tempMap.entrySet()) {
                                if (entry.getValue() instanceof ArrayList) {
                                    // 更新原来的集合
                                    _tempMap.put(entry.getKey(), _list);
                                }
                            }
                        } else {
                            // 克隆新对象,防止影响原始数据
                            _parameter = BeanUtil.copyProperties(_parameter, _parameter.getClass());
                            _args[1] = _parameter;
                            this.encryptField(_parameter, _orgMethodDataSecurity);
                        }
                        break;
                    case SELECT: // 处理查询逻辑
                        if (_orgMethod != null) {
                            // 克隆新对象,防止影响原始数据
                            _parameter = this.cloneObject(_parameter);
                            this.encryptField(_parameter, _orgMethodDataSecurity);
                            _args[1] = _parameter;
                        }
                        Object _queryResult = invocation.proceed();
                        if (_orgMethod != null) {
                            // 解密,防止影响下一次的查询
                            this.decryptField(_parameter, _orgMethodDataSecurity);
                            _args[1] = _parameter;
                        }
                        boolean _isList = _queryResult instanceof ArrayList;
                        // log.info("查询结果是不是集合:" + _isList);
                        if (_isList) {
                            List<?> _list = (List<?>) _queryResult;
                            for (Object _entity : _list) {
                                this.decryptField(_entity, _orgMethodDataSecurity);
                            }
                        } else {
                            // 经测试,发现查询单条数据,也是返回集合,故这里先不进行处理。
                        }
                        return _queryResult;
                }
            }
        }
        return invocation.proceed();
    }

    /**
     * 克隆对象。
     *
     * @param source 源对象
     * @return 返回克隆后的对象。
     */
    private Object cloneObject(Object source) {
        Object _result = source;
        if (_result != null) {
            try {
                Class<?> _sourceClazz = _result.getClass();
                if (!_sourceClazz.getSimpleName().endsWith("Example") && this.excludedDataTypeList.indexOf(_sourceClazz) == -1) {
                    // 不需要排除的数据类型,则进行克隆
                    if (_sourceClazz == ArrayList.class || _sourceClazz == List.class) {
                        List<?> _list = (List<?>) Optional.of(_result).orElse(new ArrayList<>());
                        if (CollectionUtil.isNotEmpty(_list)) {
                            _result = BeanUtils.copyToList(_list, _list.get(0).getClass());
                        }
                    } else {
                        _result = BeanUtil.copyProperties(_result, _sourceClazz);
                    }
                }
            } catch (Exception ex) {
            }
        }
        return _result;
    }

    /**
     * 给字段进行加密。
     *
     * @param entity             待加密的实体对象
     * @param methodDataSecurity 方法上被注解的数据安全对象
     */
    private void encryptField(Object entity, DataSecurity methodDataSecurity) {
        this.updateField(entity, methodDataSecurity, 0);
    }

    /**
     * 给字段进行解密。
     *
     * @param entity             待解密的实体对象
     * @param methodDataSecurity 方法上被注解的数据安全对象
     */
    private void decryptField(Object entity, DataSecurity methodDataSecurity) {
        this.updateField(entity, methodDataSecurity, 1);
    }

    /**
     * 更新字段属性值。
     *
     * @param entity             待更新的实体对象
     * @param methodDataSecurity 方法上被注解的数据安全对象
     * @param type               更新类型,0:表示加密,1:表示解密
     */
    private void updateField(Object entity, DataSecurity methodDataSecurity, int type) {
        if (entity == null) {
            return;
        }
        if (entity instanceof Map) {
            Map<String, Object> _tempMap = (Map<String, Object>) entity;
            Set<Map.Entry<String, Object>> _entrys = _tempMap.entrySet();
            List<String> _tempDataSecurityValueList = new ArrayList<>();
            if (methodDataSecurity != null) {
                _tempDataSecurityValueList = Arrays.asList(methodDataSecurity.value());
            }
            for (Map.Entry<String, Object> _entry : _entrys) {
                if (_entry.getKey().equalsIgnoreCase(CRITERIA_KEY_NAME)) {
                    this.handlerCriteria(_entry.getValue(), methodDataSecurity, type);
                } else {
                    if (_tempDataSecurityValueList.contains(_entry.getKey())) {
                        _tempMap.put(_entry.getKey(), this.formatValue(methodDataSecurity, _entry.getValue().toString(), type));
                    } else {
                        this.updateField(_entry.getValue(), methodDataSecurity, type);
                    }
                }
            }
            return;
        }
        if (entity instanceof ArrayList) {
            List<?> _list = (List<?>) entity;
            for (Object _item : _list) {
                this.updateField(_item, methodDataSecurity, type);
            }
            return;
        }
        Class<?> _clazz = entity.getClass();
        Field[] _fields = _clazz.getDeclaredFields();
        for (Field _field : _fields) {
            _field.setAccessible(true);
            if (methodDataSecurity != null && _field.getName().equalsIgnoreCase(CRITERIA_KEY_NAME)) {
                // 针对 Example 的查询方式进行特殊处理
                try {
                    this.handlerCriteria(_field.get(entity), methodDataSecurity, type);
                } catch (IllegalAccessException e) {
                    // throw new RuntimeException(e);
                }
            } else if (_field.isAnnotationPresent(DataSecurity.class)) {
                // 获取要加密的字段值
                DataSecurity _mDataSecurity = _field.getAnnotation(DataSecurity.class);
                try {
                    Object _value = _field.get(entity);
                    if (_value != null) {
                        Object _tempValue = this.formatValue(_mDataSecurity, _value.toString(), type);
                        if (_tempValue != null) {
                            // 将加解密后的数据设置回去
                            _field.set(entity, _tempValue);
                        }
                    }
                } catch (IllegalAccessException e) {
                    // throw new RuntimeException(e);
                }
            }
        }
    }

    /**
     * 处理 Example 查询数据。
     *
     * @param entity       待更新的实体对象
     * @param dataSecurity 方法上被注解的数据安全对象
     * @param type         更新类型,0:表示加密,1:表示解密
     */
    private void handlerCriteria(Object entity, DataSecurity dataSecurity, int type) {
        if (entity == null || dataSecurity == null) {
            return;
        }
        // 针对 Example 的查询方式进行特殊处理
        List<?> _list = (List<?>) entity;
        // log.info("{}", _list);
        List<String> _tempDataSecurityValueList = Arrays.asList(dataSecurity.value());
        if (CollectionUtil.isNotEmpty(_list) && CollectionUtil.isNotEmpty(_tempDataSecurityValueList)) {
            _list = ReflectUtil.invoke(_list.get(0), CRITERIA_GET_KEY_NAME);
            for (Object _item : _list) {
                try {
                    String _tempCondition = ReflectUtil.invoke(_item, CRITERIA_GET_CONDITION_KEY_NAME);
                    // log.info("{}", _tempCondition);
                    String _keyString = " =";
                    if (!_tempCondition.trim().endsWith(_keyString)) {
                        continue;
                    }
                    _tempCondition = StrUtil.toCamelCase(_tempCondition.substring(0, _tempCondition.length() - _keyString.length()));
                    if (!_tempDataSecurityValueList.contains(_tempCondition)) {
                        continue;
                    }
                    Object _tempValue = ReflectUtil.invoke(_item, CRITERIA_GET_VALUE_KEY_NAME);
                    if (!(_tempValue instanceof String)) {
                        continue;
                    }
                    // log.info("{} {}", _tempCondition, _tempValue);
                    _tempValue = this.formatValue(dataSecurity, _tempValue.toString(), type);
                    if (_tempValue != null) {
                        try {
                            Field _tempField = ReflectUtil.getField(_item.getClass(), CRITERIA_VALUE_KEY_NAME);
                            _tempField.setAccessible(true);
                            _tempField.set(_item, _tempValue);
                        } catch (IllegalAccessException e) {
                            // throw new RuntimeException(e);
                        }
                    }
                    // log.info("{} {}", _tempCondition, _tempValue);
                } catch (Exception ex) {
                }
            }
        }
    }

    /**
     * 格式化值。
     *
     * @param dataSecurity 数据安全对象
     * @param orgValue     原数据值
     * @param type         更新类型,0:表示加密,1:表示解密
     * @return 返回格式化后的值。
     */
    private String formatValue(DataSecurity dataSecurity, String orgValue, int type) {
        if (dataSecurity == null || orgValue == null) {
            return orgValue;
        }
        String _tempValue = null;
        try {
            // 使用加解密算法进行加解密
            switch (type) {
                case 0: // 加密
                    switch (dataSecurity.algorithm()) {
                        case BASE64:
                            _tempValue = Base64Utils.encodeToString(orgValue.getBytes(StandardCharsets.UTF_8));
                            break;
                        case SM2:
                            _tempValue = this.mApiEncryptService.encrypt2Data(orgValue);
                            break;
                        case SM4:
                            _tempValue = this.mApiEncryptService.encrypt4Data(orgValue);
                            break;
                    }
                    break;
                case 1: // 解密
                    switch (dataSecurity.algorithm()) {
                        case BASE64:
                            _tempValue = new String(Base64Utils.decodeFromString(orgValue), StandardCharsets.UTF_8);
                            break;
                        case SM2:
                            _tempValue = this.mApiEncryptService.decrypt2Data(orgValue);
                            break;
                        case SM4:
                            _tempValue = this.mApiEncryptService.decrypt4Data(orgValue);
                            break;
                    }
                    break;
            }
        } catch (Exception ex) {
        }
        if (_tempValue != null) {
            return _tempValue;
        }
        return orgValue;
    }
}

package com.annotation;

import com.enums.DataSecurityAlgorithmEnum;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 数据安全注解类。<br/>
 * 使用 @DataSecurity 注解标记的字段,表示这个字段需要进行加密和解密。
 */
// @Target 设置注解的使用范围,这里设置可以用在方法和字段上。
@Target({ElementType.METHOD, ElementType.FIELD})
// @Target 设置注解的使用范围,这里设置可以用在参数、方法和字段上。
// @Target({ElementType.PARAMETER, ElementType.METHOD, ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME) // 表示该注解保留到运行时。在运行时保留注解信息,可以通过反射获取注解信息。
public @interface DataSecurity {
    /**
     * 加密算法。
     *
     * @return
     */
    DataSecurityAlgorithmEnum algorithm() default DataSecurityAlgorithmEnum.SM4;

    /**
     * 待加解密的属性名称(当注解在方法上时有效)。<br/>
     * 示例:@DataSecurity(value = "name") 或 @DataSecurity(value = {"name1", "name2"})
     *
     * @return
     */
    String[] value() default "";
}

 下面这文章为ParameterHandler参数拦截加密和ResultSetHandler结果解密分开实现

Mybatis拦截器优雅的实现敏感数据的加解密_***sensitive data replaced***-CSDN博客
 

在SpringBoot项目中,自定义注解+拦截器优雅的实现敏感数据的加解密!-CSDN博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值