MyBatis Sensitive Data Encryption

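To make stored data more secure, and to limit the damage if the MySQL database is breached and its tables are dumped, this approach wraps MyBatis Mapper beans in dynamic proxies and uses the symmetric cipher AES to encrypt sensitive data on write and decrypt it on read.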

Alternatively, the same can be implemented with a MyBatis type handler (a subclass of BaseTypeHandler).

Implementation approach

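  1. Register the configuration through Spring Boot auto-configuration.
  2. Implement the BeanPostProcessor interface and, as beans are initialized, wrap the classes annotated with @EncryptMapper in a dynamic proxy.
  3. The dynamic proxy class EncryptMapperProxy handles the parameters, fields and return values annotated with @Encrypt.
  4. AES is used for symmetric encryption: values are encrypted on write and decrypted on read.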
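The article uses @EncryptMapper and @Encrypt without showing their definitions. Below is a minimal sketch of what they could look like, assuming runtime retention (the proxy reads them via reflection), @Encrypt being allowed on fields, parameters and methods (as in the usage example), and a secretKey attribute whose default placeholder `${mybatis.encrypt.secret-key.default}` is a guess based on the application.yml shown later. The helper class `EncryptAnnotationDemo` is hypothetical; it lists the @Encrypt String fields of a class, which is what the proxy looks for.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

/** Marks a Mapper interface whose calls go through the encrypting proxy (assumed definition). */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface EncryptMapper {
}

/** Marks a String field, parameter or return value as encrypted (assumed definition). */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD})
@interface Encrypt {
    // Assumed default: a Spring placeholder, resolved later by a StringValueResolver.
    String secretKey() default "${mybatis.encrypt.secret-key.default}";
}

class EncryptAnnotationDemo {
    static class Student {
        int id;
        @Encrypt
        String name;
        @Encrypt(secretKey = "${mybatis.encrypt.secret-key.student-phone}")
        String phone;
    }

    /** Returns "fieldName=secretKey" for every @Encrypt field of type String. */
    static List<String> describeEncryptFields(Class<?> clazz) {
        List<String> result = new ArrayList<>();
        for (Field field : clazz.getDeclaredFields()) {
            Encrypt encrypt = field.getAnnotation(Encrypt.class);
            if (encrypt != null && field.getType() == String.class) {
                result.add(field.getName() + "=" + encrypt.secretKey());
            }
        }
        return result;
    }
}
```

The order of `getDeclaredFields()` is unspecified, so callers should not rely on the order of the returned list.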

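The ciphertext is stored Base64-encoded so that old (plaintext) and new (encrypted) data can coexist during the switch-over: on read, only values that look like Base64 are decrypted, and if decryption fails the stored value is returned unchanged.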
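The decryption method in EncryptMapperProxy relies on commons-codec's `Base64.isBase64`, which only checks that every character belongs to the Base64 alphabet. A plaintext such as an 11-digit phone number therefore also qualifies, and is presumably returned unchanged only because the AES decryption attempt then fails. A stricter check is sketched below with `java.util.Base64`, assuming the stored value is padded standard Base64 of AES output (always a whole number of 16-byte blocks). It reuses the same regular expression as the SQL section at the end; `CiphertextDetector` is a hypothetical name, and the result is still a heuristic.

```java
import java.util.Base64;
import java.util.regex.Pattern;

class CiphertextDetector {
    // Same shape as the REGEXP in the SQL section: standard alphabet, length a multiple of 4.
    private static final Pattern BASE64 =
            Pattern.compile("([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?");
    private static final int AES_BLOCK_SIZE = 16;

    /** Heuristic: true if the value could be Base64-encoded AES ciphertext. */
    static boolean looksEncrypted(String value) {
        if (value == null || value.isEmpty() || !BASE64.matcher(value).matches()) {
            return false;
        }
        byte[] decoded = Base64.getDecoder().decode(value);
        // AES with PKCS5/PKCS7 padding always produces whole 16-byte blocks.
        return decoded.length % AES_BLOCK_SIZE == 0;
    }
}
```

For example, "13800138000" fails the length check and "abcd" decodes to only 3 bytes, so neither is treated as ciphertext.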

Implementation code

Bean registration

Spring Boot auto-configuration

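package com.example.mybatis.config;

import org.mybatis.spring.mapper.MapperFactoryBean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.AutoConfigureBefore;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;

/**
 * Registers EncryptMapperBeanPostProcessor.
 */
@Configuration
@ConditionalOnClass(MapperFactoryBean.class)
@AutoConfigureBefore(name = "org.mybatis.spring.boot.autoconfigure.MybatisAutoConfiguration")
@Order(Ordered.HIGHEST_PRECEDENCE)
public class MybatisEncryptAutoConfiguration {
    private static final Logger LOGGER = LoggerFactory.getLogger(MybatisEncryptAutoConfiguration.class);

    // static: a BeanPostProcessor is created very early, and a static @Bean method lets Spring
    // do that without first instantiating this configuration class
    @Bean
    public static EncryptMapperBeanPostProcessor encryptMapperBeanPostProcessor() {
        LOGGER.info("EncryptMapperBeanPostProcessor create");
        return new EncryptMapperBeanPostProcessor();
    }
}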

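Registration file: resources/META-INF/spring.factories. (Spring Boot 2.7 and later read auto-configuration classes from META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports, one class name per line; Spring Boot 3 no longer reads them from spring.factories.)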

org.springframework.boot.autoconfigure.EnableAutoConfiguration=com.example.mybatis.config.MybatisEncryptAutoConfiguration

Adding the dynamic proxy

Implement the BeanPostProcessor interface and, as beans are initialized, wrap the mappers annotated with @EncryptMapper in a dynamic proxy:

import java.lang.reflect.Proxy;
import java.util.HashSet;
import java.util.Set;

import org.mybatis.spring.mapper.MapperFactoryBean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.EmbeddedValueResolverAware;
import org.springframework.util.StringValueResolver;

// EncryptMapper and EncryptMapperProxy are project classes; the article does not show their package.
public class EncryptMapperBeanPostProcessor implements BeanPostProcessor, EmbeddedValueResolverAware {

    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptMapperBeanPostProcessor.class);

    /** Mapper interfaces annotated with @EncryptMapper, collected from their MapperFactoryBeans. */
    private final Set<Class<?>> mapperClasses = new HashSet<>();

    private StringValueResolver resolver;

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }

    /** True if the interface itself, or any interface it extends, is annotated with @EncryptMapper. */
    private static boolean isEncryptMapper(Class<?> mapperClass) {
        if (mapperClass.getAnnotation(EncryptMapper.class) != null) {
            return true;
        }
        for (Class<?> t : mapperClass.getInterfaces()) {
            // check every super-interface, not only the first one
            if (isEncryptMapper(t)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if (bean instanceof MapperFactoryBean) {
            // the factory bean itself: remember which mapper interface it produces
            Class<?> mapperInterface = ((MapperFactoryBean<?>) bean).getMapperInterface();
            if (isEncryptMapper(mapperInterface)) {
                mapperClasses.add(mapperInterface);
                LOGGER.info("find MapperFactoryBean beanName={} mapperInterface={}", beanName, mapperInterface);
            } else {
                LOGGER.info("ignore MapperFactoryBean beanName={} mapperInterface={}", beanName, mapperInterface);
            }
        } else {
            // the mapper instance returned by the factory bean: wrap it if its interface was recorded
            for (Class<?> mapperClass : mapperClasses) {
                if (bean.getClass().equals(mapperClass)) {
                    LOGGER.info("proxy Mapper beanName={} class={} ", beanName, mapperClass);
                    return createMapperProxy(bean, mapperClass);
                }
                for (Class<?> interfaceClass : bean.getClass().getInterfaces()) {
                    if (interfaceClass.equals(mapperClass)) {
                        LOGGER.info("proxy Mapper beanName={} class={} ", beanName, mapperClass);
                        return createMapperProxy(bean, mapperClass);
                    }
                }
            }
        }
        return bean;
    }

    private Object createMapperProxy(Object bean, Class<?> mapperClass) {
        return Proxy.newProxyInstance(mapperClass.getClassLoader(), new Class<?>[]{mapperClass},
                new EncryptMapperProxy<>(mapperClass, bean, resolver));
    }

    @Override
    public void setEmbeddedValueResolver(StringValueResolver resolver) {
        this.resolver = resolver;
    }
}

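BeanPostProcessor (a "bean post-processor") is a Spring interface whose two methods the container calls back while creating each bean: postProcessBeforeInitialization runs before the bean's initialization methods (such as afterPropertiesSet or a custom init-method), and postProcessAfterInitialization runs after them. Whatever postProcessAfterInitialization returns is the object the container keeps, which is how a mapper gets replaced by its encrypting proxy.
For beans produced by a FactoryBean such as MapperFactoryBean, Spring applies postProcessAfterInitialization both to the factory bean itself and to the object its getObject() returns. That is why the processor above first records the mapper interface from the MapperFactoryBean and then wraps the mapper instance.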
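The two ideas in the processor above, "the object returned from the after-initialization callback replaces the bean" and "look for the annotation through the whole interface hierarchy", can be tried without Spring. The JDK-only sketch below is illustrative: `AfterInitialization` is a hypothetical stand-in for Spring's BeanPostProcessor, not its API. The annotated interface is deliberately not the first super-interface, which is the case a loop that returns after checking only the first interface would miss.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;

/** Same role as the article's @EncryptMapper (definition assumed). */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface EncryptMapper {
}

class PostProcessDemo {
    /** Hypothetical stand-in for BeanPostProcessor#postProcessAfterInitialization. */
    interface AfterInitialization {
        Object postProcessAfterInitialization(Object bean, String beanName);
    }

    interface Auditable {
    }

    @EncryptMapper
    interface BaseMapper {
        String find(String key);
    }

    /** The annotated interface is the second super-interface. */
    interface StudentMapper extends Auditable, BaseMapper {
    }

    static class StudentMapperImpl implements StudentMapper {
        @Override
        public String find(String key) {
            return "row:" + key;
        }
    }

    /** Checks the type itself and, recursively, all of its super-interfaces. */
    static boolean isEncryptMapper(Class<?> type) {
        if (type.isAnnotationPresent(EncryptMapper.class)) {
            return true;
        }
        for (Class<?> parent : type.getInterfaces()) {
            if (isEncryptMapper(parent)) {
                return true;
            }
        }
        return false;
    }

    /** Returns a forwarding JDK proxy for encrypt mappers; any other bean is returned as-is. */
    static final AfterInitialization PROCESSOR = (bean, beanName) -> {
        for (Class<?> iface : bean.getClass().getInterfaces()) {
            if (isEncryptMapper(iface)) {
                return Proxy.newProxyInstance(iface.getClassLoader(), new Class<?>[]{iface},
                        (proxy, method, args) -> {
                            try {
                                return method.invoke(bean, args); // forward to the original bean
                            } catch (InvocationTargetException e) {
                                throw e.getCause();
                            }
                        });
            }
        }
        return bean;
    };
}
```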

Encryption and decryption in the dynamic proxy

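The dynamic proxy class EncryptMapperProxy handles everything annotated with @Encrypt. @Encrypt String parameters are replaced by their ciphertext; the @Encrypt String fields of object parameters are encrypted in place for the duration of the call and restored afterwards; @Encrypt return values (a String or a collection of String) and the @Encrypt fields of returned objects are decrypted. The cipher is AES: encryption on write, decryption on read.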

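import java.io.Serializable;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Blob;
import java.sql.Clob;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.codec.binary.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringValueResolver;

// Encrypt and AESCryptos are project classes; the article does not show their package.
public class EncryptMapperProxy<T> implements InvocationHandler, Serializable {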
    private final static Logger LOGGER = LoggerFactory.getLogger(EncryptMapperProxy.class);
    private static ConcurrentHashMap<Class, List<Field>> classAndEncryptFields = new ConcurrentHashMap<>();
    private static ConcurrentHashMap<Method, int[]> methodAndEncryptParameterIndexes = new ConcurrentHashMap<>();
    private static ConcurrentHashMap<Method, Map<Integer, Encrypt>> methodAndParameterEncrypts = new ConcurrentHashMap<>();

    private Class<T> mapperClass;
    private Object mapper;
    private StringValueResolver resolver;

    public EncryptMapperProxy(Class<T> mapperClass, Object mapper, StringValueResolver resolver) {
        this.mapperClass = mapperClass;
        this.mapper = mapper;
        this.resolver = resolver;
    }

    public Class<T> getMapperClass() {
        return mapperClass;
    }

    private int[] getEncryptParameterIndex(Method method) {
        if (!methodAndEncryptParameterIndexes.containsKey(method)) {
            int index = 0;
            List<Integer> indexes = new ArrayList<>();
            int parameterLength = method.getParameters().length;
            for (Parameter p : method.getParameters()) {
                if (!isNotSupportType(p.getType())) {
                    if (p.getType().equals(String.class)) {
                        Encrypt encrypt = p.getAnnotation(Encrypt.class);
                        if (encrypt != null) {
                            methodAndParameterEncrypts.putIfAbsent(method, new HashMap<>(Math.max(1, parameterLength / 2)));
                            methodAndParameterEncrypts.get(method).put(index, encrypt);
                            indexes.add(index);
                        }
                    } else {
                        indexes.add(index);
                    }
                }
                index++;
            }
            int[] result = new int[indexes.size()];
            index = 0;
            for (int i : indexes) {
                result[index++] = i;
            }
            methodAndEncryptParameterIndexes.putIfAbsent(method, result);
            return result;
        }
        return methodAndEncryptParameterIndexes.get(method);
    }

    private String genNewThreadName(Method method) {
        return Thread.currentThread().getName() + "-(" + mapperClass.getSimpleName() + "." + method.getName() + "(...)" + ")";
    }

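    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        String threadName = Thread.currentThread().getName();
        Thread.currentThread().setName(genNewThreadName(method));
        if (LOGGER.isDebugEnabled()) {
            // do not log "proxy": its toString() would call back into this handler and recurse
            LOGGER.debug("invoke mapper={} method={}", mapperClass.getName(), method.getName());
        }
        Map<Object, Map<Field, String>> resetArgFieldValue = Collections.emptyMap();
        try {
            if (args != null && args.length > 0) {
                resetArgFieldValue = new HashMap<>(args.length);
                int[] encryptParameterIndexes = getEncryptParameterIndex(method);
                Map<Integer, Encrypt> parameterEncrypts = methodAndParameterEncrypts.get(method);
                for (int pIndex : encryptParameterIndexes) {
                    Object arg = args[pIndex];
                    if (arg == null) {
                        continue;
                    }
                    if (arg instanceof String) {
                        // only parameters declared as @Encrypt String have an entry; a String passed
                        // to a parameter of another type (e.g. Object) is left unchanged
                        Encrypt encrypt = parameterEncrypts == null ? null : parameterEncrypts.get(pIndex);
                        if (encrypt != null) {
                            args[pIndex] = encryption(encrypt, (String) arg);
                        }
                    } else {
                        // fields changed here are restored to their old values in the finally block
                        Map<Field, String> fieldAndOldValues = handlerFields(arg, false, true);
                        if (!fieldAndOldValues.isEmpty()) {
                            resetArgFieldValue.put(arg, fieldAndOldValues);
                        }
                    }
                }
            }

            Object result;
            try {
                result = method.invoke(mapper, args);
            } catch (InvocationTargetException e) {
                // rethrow the mapper's own exception instead of letting the JDK proxy wrap the
                // checked InvocationTargetException in an UndeclaredThrowableException
                throw e.getCause();
            }
            if (method.getReturnType().equals(Void.TYPE) || result == null) {
                return result;
            }
            return encryptResult(method, result);
        } finally {
            try {
                resetArgFieldValue(resetArgFieldValue);
            } finally {
                Thread.currentThread().setName(threadName);
            }
        }
    }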

    private void resetArgFieldValue(Map<Object, Map<Field, String>> resetArgFieldValue) throws IllegalAccessException {
        for (Map.Entry<Object, Map<Field, String>> entry : resetArgFieldValue.entrySet()) {
            Object object = entry.getKey();
            for (Map.Entry<Field, String> fieldAndOldValue : entry.getValue().entrySet()) {
                Field field = fieldAndOldValue.getKey();
                field.setAccessible(true);
                field.set(object, fieldAndOldValue.getValue());
                // the restored value is the caller's plaintext, so only the field name is logged
                LOGGER.debug("reset class={} field={}", object.getClass().getName(), field.getName());
            }
        }
    }


    private Object encryptResult(Method method, Object object) {
        Class<?> clazz = method.getReturnType();
        if (clazz.equals(String.class)) {
            Encrypt encrypt = method.getAnnotation(Encrypt.class);
            if (encrypt != null) {
                String value = (String) object;
                return decryption(encrypt, value);
            }
        } else if (object instanceof Collection) {
            return decipherResultCollectionField(method, (Collection) object);
        } else {
            return decipherResultObjectField(method, object);
        }
        return object;
    }

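    private String decryption(Encrypt encrypt, String value) {
        // A value that looks like Base64 is treated as ciphertext, so rows written before encryption
        // was switched on (plaintext) can still be read. Null and "" are returned unchanged.
        if (value != null && !value.isEmpty() && Base64.isBase64(value)) {
            String secretKey = resolver.resolveStringValue(encrypt.secretKey());
            try {
                return AESCryptos.aesDecrypt(secretKey, value);
            } catch (RuntimeException e) {
                // expected for old plaintext that merely looks like Base64; the value itself is not
                // logged because it may be plaintext
                LOGGER.warn("decryption failed, returning the stored value unchanged", e);
            }
        }
        return value;
    }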

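    /**
     * Types that cannot carry @Encrypt fields: primitives and their wrappers, LOBs, byte[],
     * collections and maps. getEncryptParameterIndex passes the declared parameter type here.
     */
    private static boolean isNotSupportType(Class<?> type) {
        return type.isPrimitive()
                || Number.class.isAssignableFrom(type)
                || Boolean.class.equals(type)
                || Character.class.equals(type)
                || Blob.class.isAssignableFrom(type)
                || Clob.class.isAssignableFrom(type)
                || byte[].class.equals(type)
                || Collection.class.isAssignableFrom(type)
                || Map.class.isAssignableFrom(type);
    }

    private static boolean isNotSupportType(Object object) {
        // getClass() of a boxed value is never primitive, so check the runtime type instead
        return isNotSupportType(object.getClass());
    }

    private String encryption(Encrypt encrypt, String plaintext) {
        if (plaintext != null && !plaintext.isEmpty()) {
            String secretKey = resolver.resolveStringValue(encrypt.secretKey());
            return AESCryptos.aesEncryptAndBase64(secretKey, plaintext);
        }
        return plaintext;
    }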


    private Collection decipherResultCollectionField(Method method, Collection collection) {
        if (collection == null || collection.isEmpty()) {
            return collection;
        }
        Class typeArgClass = getReturnCollectionGenericType(method);
        LOGGER.debug("collection genericType:{}", typeArgClass);
        if (typeArgClass != null && typeArgClass.equals(String.class)) {
            Encrypt encrypt = method.getAnnotation(Encrypt.class);
            if (encrypt == null) {
                return collection;
            }
            Collection<String> result;
            if (collection instanceof List) {
                result = new ArrayList<>(collection.size());
            } else {
                result = new HashSet<>(collection.size());
            }
            for (Object obj : collection) {
                result.add(decryption(encrypt, (String) obj));
            }
            return result;
        } else {
            for (Object obj : collection) {
                decipherResultObjectField(method, obj);
            }
        }
        return collection;
    }

    private Class getReturnCollectionGenericType(Method method) {
        Type returnType = method.getGenericReturnType();
        Class typeArgClass = null;
        if (returnType instanceof ParameterizedType) {
            ParameterizedType type = (ParameterizedType) returnType;
            Type[] typeArguments = type.getActualTypeArguments();
            if (typeArguments[0] instanceof Class) {
                typeArgClass = (Class) typeArguments[0];
            }
        }
        return typeArgClass;
    }

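    /**
     * Decrypts, in place, the @Encrypt fields of a single (non-collection) result object.
     *
     * @param object the result object; may be null (e.g. a null element of a result list)
     */
    private Object decipherResultObjectField(Method method, Object object) {
        if (object == null || isNotSupportType(object) || object instanceof String) {
            return object;
        }
        handlerFields(object, true);
        return object;
    }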

    private Map<Field, String> handlerFields(Object object, boolean isDecryption) {
        return this.handlerFields(object, isDecryption, false);
    }

    /**
     * Encrypts or decrypts, in place, every @Encrypt String field of the given object.
     *
     * @param object       a mapper parameter (encrypt) or a result object (decrypt)
     * @param isDecryption true to decrypt, false to encrypt
     * @param needOldValue true to return the previous values so the caller can restore them
     * @return the old values of the modified fields, or an empty map if needOldValue is false
     */
    private Map<Field, String> handlerFields(Object object, boolean isDecryption, boolean needOldValue) {
        if (isNotSupportType(object) || object instanceof String) {
            return Collections.emptyMap();
        }
        List<Field> encryptFields = extractEncryptFields(object.getClass());
        if (encryptFields != null && !encryptFields.isEmpty()) {
            Map<Field, String> result = needOldValue ? new HashMap<>(encryptFields.size()) : Collections.emptyMap();
            for (Field field : encryptFields) {
                Encrypt encrypt = field.getAnnotation(Encrypt.class);
                try {
                    field.setAccessible(true);
                    String oldValue = (String) field.get(object);
                    String newValue = isDecryption ? decryption(encrypt, oldValue) : encryption(encrypt, oldValue);
                    if (needOldValue) {
                        result.put(field, oldValue);
                    }
                    field.set(object, newValue);
                    // only names are logged: the values are sensitive
                    LOGGER.debug("{} value object={} field={}", (isDecryption ? "decryption" : "encryption"), object.getClass().getName(), field.getName());
                } catch (IllegalAccessException e) {
                    LOGGER.error("{} value fail object={} field={}", (isDecryption ? "decryption" : "encryption"), object.getClass().getName(), field.getName(), e);
                }
            }
            return result;
        }
        return Collections.emptyMap();
    }

    private List<Field> extractEncryptFields(Class<?> clazz) {
        return classAndEncryptFields.computeIfAbsent(clazz, key -> {
            List<Field> encryptFields = new ArrayList<>();
            // also scan superclasses, so @Encrypt fields inherited from a base entity are handled
            for (Class<?> current = clazz; current != null && current != Object.class; current = current.getSuperclass()) {
                for (Field field : current.getDeclaredFields()) {
                    if (!field.isAnnotationPresent(Encrypt.class)) {
                        continue;
                    }
                    if (field.getType().equals(String.class)) {
                        encryptFields.add(field);
                        LOGGER.debug("add class={} encrypt field={}", clazz.getName(), field.getName());
                    } else {
                        LOGGER.info("class={} field={} not String type skip encrypt", clazz.getName(), field.getName());
                    }
                }
            }
            return encryptFields.isEmpty() ? Collections.<Field>emptyList() : encryptFields;
        });
    }
}
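The parameter path of the proxy above boils down to one pattern: encrypt the @Encrypt String fields of an argument in place, call the real mapper, and put the caller's plaintext back in a finally block, so the object the caller still holds is unchanged. A stripped-down, JDK-only sketch of just that pattern follows. `FieldRestoreDemo`, the string-reversing "cipher" standing in for AES, and the use of an IdentityHashMap are illustrative choices, not taken from the article.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.function.UnaryOperator;

/** Field-level marker, as in the article (definition assumed). */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface Encrypt {
}

class FieldRestoreDemo {
    static class Student {
        @Encrypt
        String name;
        String school;

        Student(String name, String school) {
            this.name = name;
            this.school = school;
        }
    }

    interface StudentMapper {
        void insert(Student student);
    }

    /** Wraps target so the @Encrypt String fields of each argument are encrypted only during the call. */
    @SuppressWarnings("unchecked")
    static <T> T wrap(Class<T> mapperType, T target, UnaryOperator<String> encrypt) {
        InvocationHandler handler = (proxy, method, args) -> {
            // Identity-keyed: an argument's equals/hashCode may depend on the fields being changed.
            Map<Object, Map<Field, String>> oldValues = new IdentityHashMap<>();
            try {
                if (args != null) {
                    for (Object arg : args) {
                        if (arg == null) {
                            continue;
                        }
                        Map<Field, String> changed = new HashMap<>();
                        oldValues.put(arg, changed); // registered first, so a failure midway is still undone
                        for (Field field : arg.getClass().getDeclaredFields()) {
                            if (field.isAnnotationPresent(Encrypt.class) && field.getType() == String.class) {
                                field.setAccessible(true);
                                String plain = (String) field.get(arg);
                                changed.put(field, plain);
                                field.set(arg, plain == null ? null : encrypt.apply(plain));
                            }
                        }
                    }
                }
                return method.invoke(target, args);
            } catch (InvocationTargetException e) {
                throw e.getCause(); // surface the mapper's own exception
            } finally {
                // Restore the caller's plaintext.
                for (Map.Entry<Object, Map<Field, String>> entry : oldValues.entrySet()) {
                    for (Map.Entry<Field, String> old : entry.getValue().entrySet()) {
                        old.getKey().set(entry.getKey(), old.getValue());
                    }
                }
            }
        };
        return (T) Proxy.newProxyInstance(mapperType.getClassLoader(), new Class<?>[]{mapperType}, handler);
    }
}
```

The wrapped mapper receives "ecilA" for a student named "Alice", while the caller's Student object reads "Alice" again once insert returns.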

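AES encryption plus Base64 encoding makes stored values longer, and short values grow the most. Assuming AES/ECB/PKCS5Padding without an IV (the format MySQL's default AES_DECRYPT in the SQL section reads), a plaintext of n UTF-8 bytes becomes 16 × (floor(n / 16) + 1) bytes of ciphertext, which Base64 turns into 4 × ceil(ciphertext bytes / 3) characters: an 11-digit phone number becomes 24 characters, and a 100-byte value becomes 152. Widen the encrypted columns accordingly.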
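A small helper that computes the stored length, useful when sizing VARCHAR columns. It assumes AES/ECB/PKCS5Padding without an IV and padded standard Base64: the ciphertext is the plaintext's UTF-8 length rounded up to the next whole 16-byte block (a full extra block when the length is already a multiple of 16), and Base64 writes 4 characters for every started group of 3 bytes. `CipherColumnLength` is a hypothetical name.

```java
import java.nio.charset.StandardCharsets;

class CipherColumnLength {
    private static final int AES_BLOCK_SIZE = 16;

    /** Length of Base64(AES/ECB/PKCS5Padding(plaintext)) for a plaintext of the given UTF-8 byte length. */
    static int base64CipherLength(int plaintextBytes) {
        // PKCS5/PKCS7 padding always adds 1..16 bytes, so the ciphertext is the next whole block.
        int cipherBytes = (plaintextBytes / AES_BLOCK_SIZE + 1) * AES_BLOCK_SIZE;
        // Base64 emits 4 characters per started group of 3 bytes ('=' padding included).
        return 4 * ((cipherBytes + 2) / 3);
    }

    static int base64CipherLength(String plaintext) {
        return base64CipherLength(plaintext.getBytes(StandardCharsets.UTF_8).length);
    }
}
```

For example, `base64CipherLength(20)` is 44, so a column that held at most 20 single-byte characters needs room for 44 once encrypted; multi-byte characters count by their UTF-8 bytes.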

Usage

Spring Boot configuration in application.yml:

mybatis:
  encrypt:
    secret-key:
      default: xxx # Base64-encoded key
      student-phone: yyy # Base64-encoded key
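The keys under mybatis.encrypt.secret-key are Base64-encoded. Since the SQL section decrypts with MySQL's AES_DECRYPT, whose default block_encryption_mode is aes-128-ecb, a random 128-bit key is a reasonable choice. The sketch below generates one with the JDK's KeyGenerator; `SecretKeyGen` is a hypothetical name.

```java
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import javax.crypto.KeyGenerator;

class SecretKeyGen {
    /** Generates a random 128-bit AES key and returns it Base64-encoded, ready for application.yml. */
    static String newBase64AesKey() {
        try {
            KeyGenerator generator = KeyGenerator.getInstance("AES");
            generator.init(128); // AES-128, matching MySQL's default aes-128-ecb mode
            return Base64.getEncoder().encodeToString(generator.generateKey().getEncoded());
        } catch (NoSuchAlgorithmException e) {
            // every Java platform is required to support AES key generation
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(newBase64AesKey()); // paste the output in place of xxx / yyy
    }
}
```

A 16-byte key always encodes to 24 Base64 characters.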

Example code

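/**
 * Marks this Mapper as encryption-enabled
 */
@EncryptMapper
public interface TestMapper {
    @Encrypt List<String> findNames(); // the returned names are decrypted
    List<Student> findByName(@Encrypt String name); // name is encrypted for the query; Student.name in the results is decrypted
    void insert(Student student); // the @Encrypt fields of Student are encrypted before the insert
    void insertName(@Encrypt String name, Date createTime); // name is encrypted before it is inserted; a distinct method name because MyBatis mapper methods cannot be overloaded
}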

public class Student {
    public int id;
    // marks a field that must be encrypted
    @Encrypt
    public String name;
    // uses a dedicated key, resolved from the Spring placeholder
    @Encrypt(secretKey = "${mybatis.encrypt.secret-key.student-phone}")
    public String phone;
    public Date createTime;
}
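The proxy decides what to encrypt purely from these signatures. Below is a JDK-only sketch of the two reflective lookups involved: which String parameters carry @Encrypt, and the element type of a List return value (the jobs of getEncryptParameterIndex and getReturnCollectionGenericType above). It runs against a trimmed copy of TestMapper. The @Encrypt definition is assumed, `SignatureInspector` is a hypothetical name, and the two-argument insert is called insertName because MyBatis mapper methods cannot be overloaded.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;

/** Assumed definition, usable on fields, parameters and methods as in the article. */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD})
@interface Encrypt {
}

class SignatureInspector {
    /** Trimmed copy of TestMapper. */
    interface TestMapper {
        @Encrypt
        List<String> findNames();

        void insertName(@Encrypt String name, Date createTime);
    }

    /** Looks up a TestMapper method by name (names are unique in this sketch). */
    static Method method(String name) {
        for (Method m : TestMapper.class.getMethods()) {
            if (m.getName().equals(name)) {
                return m;
            }
        }
        throw new IllegalArgumentException(name);
    }

    /** Indexes of the String parameters annotated with @Encrypt. */
    static List<Integer> encryptedStringParams(Method method) {
        List<Integer> indexes = new ArrayList<>();
        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; i++) {
            if (parameters[i].getType() == String.class && parameters[i].isAnnotationPresent(Encrypt.class)) {
                indexes.add(i);
            }
        }
        return indexes;
    }

    /** Element class of a parameterized return type such as List<String>, or null. */
    static Class<?> returnElementType(Method method) {
        Type type = method.getGenericReturnType();
        if (type instanceof ParameterizedType) {
            Type arg = ((ParameterizedType) type).getActualTypeArguments()[0];
            if (arg instanceof Class) {
                return (Class<?>) arg;
            }
        }
        return null;
    }
}
```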

Inspecting the data with SQL

When every value in the column is encrypted:

SELECT id, result, AES_DECRYPT(FROM_BASE64(result),FROM_BASE64('xxx'))
FROM encrypt_test

When the column holds both encrypted and unencrypted values:

SELECT id, result,
    IF(result REGEXP '^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?$',
         AES_DECRYPT(FROM_BASE64(result),FROM_BASE64('xxx')),
        result)
FROM encrypt_test

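Function: AES_DECRYPT(FROM_BASE64(encrypted_column), FROM_BASE64('key')). This relies on MySQL's default block_encryption_mode (aes-128-ecb) matching the cipher used on the Java side. AES_DECRYPT returns a binary string, so depending on the client you may need CAST(... AS CHAR) to see readable text.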
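The proxy calls AESCryptos.aesEncryptAndBase64(secretKey, plaintext) and AESCryptos.aesDecrypt(secretKey, value), but the article does not show that class. Below is a hypothetical stand-in with the same two method names, written so that MySQL's default AES_DECRYPT can read its output. It assumes AES-128 in ECB mode with PKCS5 padding, a Base64-encoded 16-byte key, and UTF-8 text. ECB is deterministic: the same plaintext always yields the same ciphertext. That is what lets findByName(@Encrypt String name) match rows by comparing ciphertexts, but it also reveals which rows share a value and rules out LIKE or range queries on encrypted columns.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.Base64;
import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;

/** Hypothetical AESCryptos: AES-128/ECB/PKCS5Padding + Base64, readable by MySQL's default AES_DECRYPT. */
final class AESCryptos {
    private static final String TRANSFORMATION = "AES/ECB/PKCS5Padding";

    private AESCryptos() {
    }

    static String aesEncryptAndBase64(String base64Key, String plaintext) {
        byte[] cipherBytes = run(Cipher.ENCRYPT_MODE, base64Key, plaintext.getBytes(StandardCharsets.UTF_8));
        return Base64.getEncoder().encodeToString(cipherBytes);
    }

    static String aesDecrypt(String base64Key, String base64Ciphertext) {
        byte[] plainBytes = run(Cipher.DECRYPT_MODE, base64Key, Base64.getDecoder().decode(base64Ciphertext));
        return new String(plainBytes, StandardCharsets.UTF_8);
    }

    private static byte[] run(int mode, String base64Key, byte[] input) {
        byte[] key = Base64.getDecoder().decode(base64Key);
        if (key.length != 16) {
            throw new IllegalArgumentException("expected a 16-byte AES-128 key, got " + key.length);
        }
        try {
            Cipher cipher = Cipher.getInstance(TRANSFORMATION);
            cipher.init(mode, new SecretKeySpec(key, "AES"));
            return cipher.doFinal(input);
        } catch (GeneralSecurityException e) {
            // Unchecked on purpose: the proxy's decryption() catches RuntimeException and
            // returns the stored value unchanged.
            throw new IllegalStateException(e);
        }
    }
}
```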
