mybatis拦截mapper,增强数据到数据库的读写

一.背景

        最近项目上需要对数据库中某张表的特定字段做增强。具体就是写入数据库时要拦截加密处理,读取时也需要拦截解密。因为涉及到改动的地点特别多,考虑在mapper层统一处理。mapper层只有接口,没有实现类。

        一开始是打算使用AOP切面,也很符合切面的使用场景。但无奈mapper层需要使用jdk代理。而之前的配置文件已定义全局使用cglib,涉及稳定性,就没有去改这个配置。采用了迂回的方式,想到了mybatis自带的Interceptor来实现。

二.相关参考

        mybatis拦截器处理敏感字段_alleged的博客-CSDN博客_mybatis 拦截参数

        mybatis运行时拦截ParameterHandler注入参数 - Lius` - 博客园

        mybatis(3)—自定义拦截器(上)基础使用 - 简书

        感谢以上各位的博客,给了我启发,最终完成需求。

三.具体实现

1.编写Interceptor类

import org.apache.commons.collections.CollectionUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.cglib.proxy.Proxy;
import org.springframework.stereotype.Component;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.*;

//query拦截查询
//update拦截插入和更新
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
@Component
public class MapperInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        if (!isEnDecryptMapper(mappedStatement) || !isEnDecryptMapperMethod(mappedStatement)) return invocation.proceed();
        System.out.println("Intercept method: " + mappedStatement.getId());
        //获取操作类型,crud
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        if (SqlCommandType.SELECT.equals(sqlCommandType)) return enhanceDecryptBySelect(invocation.proceed());
        if (SqlCommandType.INSERT.equals(sqlCommandType) || SqlCommandType.UPDATE.equals(sqlCommandType))
            enhanceEncryptByInsertOrUpdate(invocation);
        return invocation.proceed();
    }

    /**
     * 判断是否为需要加解密的mapper(@EnDecryptMapperAnnotation)
     */
    private boolean isEnDecryptMapper(MappedStatement mappedStatement) {
        try {
            String namespace = mappedStatement.getId();
            String className = namespace.substring(0, namespace.lastIndexOf("."));
            Class<?> clazz = Class.forName(className);
            Annotation annotation = clazz.getAnnotation(EnDecryptMapperAnnotation.class);
            if (null != annotation) return true;
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            return false;
        }
        return false;
    }


    /**
     * 判断是否为需要加解密的mapper,且方法上有@EnDecryptMapperMethod
     * 不建议方法重载
     */
    private boolean isEnDecryptMapperMethod(MappedStatement mappedStatement) {
        Method method = getMapperTargetMethod(mappedStatement);
        assert method != null;
        Annotation annotation = method.getAnnotation(EnDecryptMapperMethod.class);
        return null != annotation;
    }

    private Object enhanceDecryptBySelect(Object returnValue) {
        if (returnValue != null) {
            if (returnValue instanceof ArrayList<?>) {
                List<?> oriList = (ArrayList<?>) returnValue;
                List<Object> newList = new ArrayList<>();
                if (CollectionUtils.isNotEmpty(oriList)) {
                    for (Object object : oriList) {
                        EnDecryptPojoUtils.decrypt(object);
                        newList.add(object);
                    }
                    returnValue = newList;
                }
            } else if (returnValue instanceof Map) {
                return returnValue;
            } else {
                EnDecryptPojoUtils.decrypt(returnValue);
            }
        }
        return returnValue;
    }

    private void enhanceEncryptByInsertOrUpdate(Invocation invocation) {
        Object parameter = invocation.getArgs()[1];
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        if (parameter instanceof String) {
            if (isEncryptStr(mappedStatement)) {
                System.out.println("" + parameter);
                //自定义处理
            }
        } else if (parameter instanceof Map) {
            Parameter[] params = getParams(mappedStatement);
            if (null != params && params.length == 1) {
                //只处理一个List参数的方法
                String paramName = params[0].getName();
                try {
                    Map<String, Object> oriMap = convertObjectToMap(parameter);
                    if (oriMap.get(paramName) instanceof ArrayList) {
                        List<?> oriList = (ArrayList<?>) oriMap.get(paramName);
                        List<Object> newList = new ArrayList<>();
                        if (CollectionUtils.isNotEmpty(oriList)) {
                            for (Object object : oriList) {
                                EnDecryptPojoUtils.encrypt(object);
                                newList.add(object);
                            }
                            oriMap.put(paramName, newList);
                            parameter = oriMap;
                        }
                    }
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
        } else {
            EnDecryptPojoUtils.encrypt(parameter);
        }
        invocation.getArgs()[1] = parameter;
    }

    /**
     * 获取方法的所有参数
     */
    private Parameter[] getParams(MappedStatement statement) {
        Method method = getMapperTargetMethod(statement);
        assert method != null;
        return method.getParameters();
    }

    /**
     * 判断字符串是否需要加密
     */
    private boolean isEncryptStr(MappedStatement mappedStatement) {
        boolean result = false;
        try {
            Method method = getMapperTargetMethod(mappedStatement);
            assert method != null;
            method.setAccessible(true);
            Annotation[][] parameterAnnotations = method.getParameterAnnotations();
            if (parameterAnnotations.length > 0) {
                for (Annotation[] parameterAnnotation : parameterAnnotations) {
                    for (Annotation annotation : parameterAnnotation) {
                        if (annotation instanceof EncryptField) {
                            result = true;
                            break;
                        }
                    }
                }
            }
        } catch (SecurityException e) {
            e.printStackTrace();
            result = false;
        }
        return result;
    }

    /**
     * 获取mapper层接口的方法
     */
    private Method getMapperTargetMethod(MappedStatement mappedStatement) {
        Method method = null;
        try {
            String namespace = mappedStatement.getId();
            String className = namespace.substring(0, namespace.lastIndexOf("."));
            String methodName = namespace.substring(namespace.lastIndexOf(".") + 1);
            Method[] ms = Class.forName(className).getMethods();
            for (Method m : ms) {
                if (m.getName().equals(methodName)) {
                    method = m;
                    break;
                }
            }
        } catch (SecurityException | ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
        return method;
    }

    /**
     * object 转 map
     */
    public Map<String, Object> convertObjectToMap(Object obj) throws IllegalAccessException {
        Map<String, Object> map = new LinkedHashMap<>();
        Class<?> clazz = obj.getClass();
        System.out.println(clazz);
        for (Field field : clazz.getDeclaredFields()) {
            field.setAccessible(true);
            String fieldName = field.getName();
            Object value = field.get(obj);
            if (value == null) {
                value = "";
            }
            map.put(fieldName, value);
        }
        return map;
    }

    @Override
    public Object plugin(Object target) {
        Object real = realTarget(target);
        return Plugin.wrap(real, this);
    }

    @SuppressWarnings("unchecked")
    private <T> T realTarget(Object target) {
        if (Proxy.isProxyClass(target.getClass())) {
            MetaObject metaObject = SystemMetaObject.forObject(target);
            return realTarget(metaObject.getValue("h.target"));
        }
        return (T) target;
    }

    @Override
    public void setProperties(Properties properties) {
    }
}

拦截前,先判断该mapper以及具体的method是否需要拦截处理。只需要在相应的地方加上自定义注解。

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface EnDecryptMapperAnnotation {
}

具体加注解的地方

@EnDecryptMapperAnnotation
@Repository
public interface UserMapper {

  @EnDecryptMapperMethod
  User getUserByUid(String uid);

  List<User> getAll(User user);

  @EnDecryptMapperMethod
  int insertOne(User user);

  @EnDecryptMapperMethod
  int insertBatch(List<User> userList);
}

如果接口和方法拦截通过,mappedStatement.getSqlCommandType()返回该操作是什么类型的CRUD。本次方案,拦截insert/update/select。

判断select,进入

enhanceDecryptBySelect(invocation.proceed())

方法,invocation.proceed()为查询返回结果,可对此结果直接处理后返回。

---------------------------

判断insert/update,进入

enhanceEncryptByInsertOrUpdate(invocation)

方法,

Object parameter = invocation.getArgs()[1];//方法参数
MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];

针对获取到的parameter做拦截处理。

其中对入参是list类型时,parameter instanceof Map = true,与预想的为list不同。map中含有多个key,其中有一个key=参数名,通过反射可以获取到参数名,从而得到具体的入参list,做修改。

2.自定义数据处理类


import org.apache.commons.lang3.StringUtils;

import java.lang.reflect.Field;

public class EnDecryptPojoUtils {
    /**
     * 对象t注解字段加密
     */
    public static <T> void encrypt(T t) {
        if (isEncryptAndDecrypt(t)) {
            Field[] declaredFields = t.getClass().getDeclaredFields();
            try {
                if (declaredFields.length > 0) {
                    for (Field field : declaredFields) {
                        if (field.isAnnotationPresent(EncryptField.class) && field.getType().toString().endsWith("String")) {
                            field.setAccessible(true);
                            String fieldValue = (String) field.get(t);
                            if (StringUtils.isNotEmpty(fieldValue)) {
                                field.set(t, fieldValue + "---");//todo encrypt
                            }
                            field.setAccessible(false);
                        }
                    }
                }
            } catch (IllegalAccessException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * 对含注解字段解密
     */
    public static <T> void decrypt(T t) {
        if (isEncryptAndDecrypt(t)) {
            Field[] declaredFields = t.getClass().getDeclaredFields();
            try {
                if (declaredFields.length > 0) {
                    for (Field field : declaredFields) {
                        if (field.isAnnotationPresent(DecryptField.class) && field.getType().toString().endsWith("String")) {
                            field.setAccessible(true);
                            String fieldValue = (String) field.get(t);
                            if (StringUtils.isNotEmpty(fieldValue)) {
                                field.set(t, fieldValue.replaceAll("---", ""));//todo decrypt
                            }
                        }
                    }
                }
            } catch (IllegalAccessException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * 判断是否需要加密解密的类
     */
    public static <T> Boolean isEncryptAndDecrypt(T t) {
        boolean reslut = false;
        if (t != null) {
            Object object = t.getClass().getAnnotation(EnDecryptMapperEntity.class);
            if (object != null) {
                reslut = true;
            }
        }
        return reslut;
    }
}

框架是springboot 2.5.7版本。

具体代码可参考 https://github.com/sky-91/redis/tree/mybatis-plugin-0323

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
MyBatis是一种持久层框架,通过它可以将Java对象映射到数据库表中,从而实现数据的持久化。在MyBatis中,保存数据数据库通常包括以下几个步骤: 1. 配置数据源:在MyBatis中,需要配置数据源,即数据库连接池,可以使用JDBC连接池或第三方数据源,比如Druid、Hikari等。 2. 编写Mapper接口:在MyBatis中,通常使用Mapper接口来定义对数据库的操作,需要编写Mapper接口及其对应的XML文件,包括SQL语句和参数映射等。 3. 创建SqlSession:在MyBatis中,需要创建SqlSession对象,用于执行SQL语句和管理事务等,可以通过SqlSessionFactory创建SqlSession。 4. 执行SQL语句:在MyBatis中,可以通过SqlSession执行SQL语句,包括插入、更新、删除和查询等操作。 5. 提交事务:在MyBatis中,需要手动提交事务,可以调用SqlSession的commit方法来提交事务,也可以在配置文件中设置自动提交事务。 具体保存数据数据库的过程,可以通过以下示例代码进行说明: ``` // 1. 配置数据源 DataSource dataSource = ...; // 2. 编写Mapper接口 public interface UserMapper { void insertUser(User user); void updateUser(User user); void deleteUser(int id); } // 3. 创建SqlSession SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream); SqlSession sqlSession = sqlSessionFactory.openSession(); // 4. 执行SQL语句 UserMapper userMapper = sqlSession.getMapper(UserMapper.class); User user = new User(); user.setId(1); user.setName("Tom"); user.setAge(20); userMapper.insertUser(user); // 5. 提交事务 sqlSession.commit(); ``` 在上面的示例中,首先通过数据源配置了数据库连接池,然后定义了一个UserMapper接口,其中包括了插入、更新和删除等操作。接着创建了SqlSession对象,通过getMapper方法获取了UserMapper的实现类,并调用insertUser方法向数据库中插入了一条用户记录。最后通过commit方法提交事务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值