一.背景
最近项目中需要对数据库某张表的特定字段做加解密增强:写入数据库时拦截并加密,读取时拦截并解密。由于涉及改动的地方特别多,考虑在 mapper 层统一处理;但 mapper 层只有接口,没有实现类。
一开始打算使用 AOP 切面,这也很符合切面的使用场景。但无奈 mapper 层需要使用 JDK 动态代理,而之前的配置文件已全局指定使用 CGLIB;出于稳定性考虑,没有修改这项配置,而是采用迂回的方式,改用 MyBatis 自带的 Interceptor(插件)来实现。
二.相关参考
mybatis拦截器处理敏感字段_alleged的博客-CSDN博客_mybatis 拦截参数
mybatis运行时拦截ParameterHandler注入参数 - Lius` - 博客园
感谢以上博客的作者,他们给了我启发,帮助我最终完成了需求。
三.具体实现
1.编写Interceptor类
import org.apache.commons.collections.CollectionUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.cglib.proxy.Proxy;
import org.springframework.stereotype.Component;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.*;
// query: intercepts SELECT statements (their results are decrypted)
// update: intercepts INSERT and UPDATE statements (their parameters are encrypted)
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
@Component
public class MapperInterceptor implements Interceptor {

    /**
     * Encrypts the parameters of INSERT/UPDATE statements and decrypts the results of SELECT statements.
     * Only statements whose mapper interface carries {@link EnDecryptMapperAnnotation} and whose mapper
     * method carries {@link EnDecryptMapperMethod} are affected; every other statement (including DELETE)
     * is passed through unchanged.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        if (!isEnDecryptMapper(mappedStatement) || !isEnDecryptMapperMethod(mappedStatement)) {
            return invocation.proceed();
        }
        System.out.println("Intercept method: " + mappedStatement.getId());
        // The CRUD type of the statement decides whether we decrypt results or encrypt parameters.
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        if (sqlCommandType == SqlCommandType.SELECT) {
            return enhanceDecryptBySelect(invocation.proceed());
        }
        if (sqlCommandType == SqlCommandType.INSERT || sqlCommandType == SqlCommandType.UPDATE) {
            enhanceEncryptByInsertOrUpdate(invocation);
        }
        return invocation.proceed();
    }

    /**
     * Returns whether the statement belongs to a mapper interface annotated with
     * {@link EnDecryptMapperAnnotation}; {@code false} if the namespace does not name a loadable class.
     */
    private boolean isEnDecryptMapper(MappedStatement mappedStatement) {
        Class<?> mapperClass = loadMapperClass(mappedStatement.getId());
        return mapperClass != null && mapperClass.isAnnotationPresent(EnDecryptMapperAnnotation.class);
    }

    /**
     * Returns whether the mapper method behind the statement is annotated with {@link EnDecryptMapperMethod}.
     * Returns {@code false} when no method with the statement's name exists (e.g. a generated statement whose
     * id suffix is not a method name) instead of failing with an NPE/AssertionError.
     * Overloading is not supported: the first method with a matching name is used.
     */
    private boolean isEnDecryptMapperMethod(MappedStatement mappedStatement) {
        Method method = getMapperTargetMethod(mappedStatement);
        return method != null && method.isAnnotationPresent(EnDecryptMapperMethod.class);
    }

    /**
     * Decrypts a query result in place and returns it.
     * Collections are decrypted element by element, maps are returned untouched, and any other
     * non-null value is decrypted as a single object.
     * NOTE(review): the result objects are mutated in place; if the same instances are also held by a
     * session-level cache, a repeated query could decrypt them twice — confirm this is safe.
     */
    private Object enhanceDecryptBySelect(Object returnValue) {
        if (returnValue == null || returnValue instanceof Map) {
            return returnValue;
        }
        if (returnValue instanceof Collection) {
            for (Object row : (Collection<?>) returnValue) {
                EnDecryptPojoUtils.decrypt(row);
            }
            return returnValue;
        }
        EnDecryptPojoUtils.decrypt(returnValue);
        return returnValue;
    }

    /**
     * Encrypts the parameter of an INSERT/UPDATE invocation before it is executed.
     * Entities (directly, inside a collection/array, or as values of a parameter map) are encrypted in
     * place via {@link EnDecryptPojoUtils#encrypt}; a single String parameter only reaches a custom hook.
     * NOTE(review): because encryption is in place, the caller's objects hold encrypted values afterwards.
     */
    private void enhanceEncryptByInsertOrUpdate(Invocation invocation) {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        Object parameter = invocation.getArgs()[1];
        if (parameter instanceof String) {
            if (isEncryptStr(mappedStatement)) {
                // Custom handling hook (TODO): encrypt the value and assign the result to `parameter`.
                // The value is deliberately not printed: it is sensitive plaintext.
            }
        } else if (parameter instanceof Map) {
            // Read the map's own entries; the entities are encrypted in place, so the map itself (with all
            // of its keys) is handed to MyBatis unchanged.
            encryptMapParameter((Map<?, ?>) parameter);
        } else {
            EnDecryptPojoUtils.encrypt(parameter);
        }
        invocation.getArgs()[1] = parameter;
    }

    /**
     * Encrypts every entity reachable from a parameter map: plain values, and the elements of collection
     * or array values. The same object may be reachable through several keys, so identity tracking makes
     * sure each object is encrypted exactly once.
     */
    private void encryptMapParameter(Map<?, ?> paramMap) {
        Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        for (Object value : paramMap.values()) {
            if (value instanceof Collection) {
                for (Object element : (Collection<?>) value) {
                    encryptOnce(element, visited);
                }
            } else if (value instanceof Object[]) {
                for (Object element : Arrays.asList((Object[]) value)) {
                    encryptOnce(element, visited);
                }
            } else {
                encryptOnce(value, visited);
            }
        }
    }

    /** Encrypts {@code target} unless it is null or has already been encrypted during this pass. */
    private static void encryptOnce(Object target, Set<Object> visited) {
        if (target != null && visited.add(target)) {
            EnDecryptPojoUtils.encrypt(target);
        }
    }

    /**
     * Returns whether any parameter of the mapper method is annotated with {@link EncryptField};
     * {@code false} if the mapper method cannot be resolved.
     */
    private boolean isEncryptStr(MappedStatement mappedStatement) {
        Method method = getMapperTargetMethod(mappedStatement);
        if (method == null) {
            return false;
        }
        for (Annotation[] parameterAnnotations : method.getParameterAnnotations()) {
            for (Annotation annotation : parameterAnnotations) {
                if (annotation instanceof EncryptField) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Resolves the mapper interface method for a statement id of the form {@code namespace.methodName}.
     * Returns the first public method (including inherited ones) with that name, or {@code null} if the
     * class cannot be loaded or has no such method. Overloaded methods are ambiguous.
     */
    private Method getMapperTargetMethod(MappedStatement mappedStatement) {
        String statementId = mappedStatement.getId();
        Class<?> mapperClass = loadMapperClass(statementId);
        if (mapperClass == null) {
            return null;
        }
        String methodName = statementId.substring(statementId.lastIndexOf('.') + 1);
        try {
            for (Method candidate : mapperClass.getMethods()) {
                if (candidate.getName().equals(methodName)) {
                    return candidate;
                }
            }
        } catch (SecurityException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Loads the class named by the namespace part (everything before the last '.') of a statement id,
     * or returns {@code null} if the id has no namespace or the namespace is not a loadable class.
     */
    private static Class<?> loadMapperClass(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        try {
            return Class.forName(statementId.substring(0, lastDot));
        } catch (ClassNotFoundException e) {
            // A namespace need not name a class; such statements are simply not intercepted.
            return null;
        }
    }

    /**
     * Copies the fields declared directly on {@code obj}'s class (inherited fields are not included)
     * into an insertion-ordered map keyed by field name, substituting "" for null values.
     * Static fields declared on the class are included as well.
     */
    public Map<String, Object> convertObjectToMap(Object obj) throws IllegalAccessException {
        Map<String, Object> map = new LinkedHashMap<>();
        for (Field field : obj.getClass().getDeclaredFields()) {
            field.setAccessible(true);
            Object value = field.get(obj);
            map.put(field.getName(), value == null ? "" : value);
        }
        return map;
    }

    /**
     * Wraps the target with this interceptor.
     * The target is wrapped as-is: unwrapping an existing proxy first would discard any other plugin
     * already applied to it (and the cglib {@code Proxy.isProxyClass} check formerly used for that never
     * recognised JDK dynamic proxies anyway).
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** No configurable properties. */
    @Override
    public void setProperties(Properties properties) {
    }
}
拦截时,先判断该 mapper 接口以及具体的方法是否需要处理:只需在相应的接口和方法上加上自定义注解即可。
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a mapper interface as a candidate for field encryption/decryption.
 * Methods of the interface must additionally opt in individually.
 * Retained at runtime so it can be read via reflection; applicable to types only.
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface EnDecryptMapperAnnotation {
}
注解的具体使用位置如下:
/**
 * Mapper for users.
 * The class-level {@code @EnDecryptMapperAnnotation} opts this mapper into the encryption/decryption
 * interceptor. Each method must additionally carry {@code @EnDecryptMapperMethod} to be processed.
 */
@Repository
@EnDecryptMapperAnnotation
public interface UserMapper {

    /** Opted in: the returned user is passed through decryption. */
    @EnDecryptMapperMethod
    User getUserByUid(String uid);

    /** Not opted in (no {@code @EnDecryptMapperMethod}): results are returned untouched. */
    List<User> getAll(User user);

    /** Opted in: the user is passed through encryption before the insert. */
    @EnDecryptMapperMethod
    int insertOne(User user);

    /** Opted in: the users in the list are passed through encryption before the insert. */
    @EnDecryptMapperMethod
    int insertBatch(List<User> userList);
}
如果接口和方法都通过了上述校验,再通过 mappedStatement.getSqlCommandType() 获取该操作的 CRUD 类型。本方案拦截 insert、update、select 三种类型。
若为 select,进入 enhanceDecryptBySelect(invocation.proceed()) 方法:invocation.proceed() 的返回值即为查询结果,可对该结果直接处理后返回。
---------------------------
若为 insert/update,进入 enhanceEncryptByInsertOrUpdate(invocation) 方法:
Object parameter = invocation.getArgs()[1]; // 方法参数
MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
然后针对获取到的 parameter 做拦截处理。
其中,当入参是 List 类型时,parameter instanceof Map 的结果为 true,而不是预想中的 List。原因是 MyBatis 会把该参数包装进一个 Map,Map 中含有多个 key,其中一个 key 就是参数名。通过反射获取到方法的参数名后,即可从该 Map 中取出具体的入参 List 进行修改。
2.自定义数据处理类
import org.apache.commons.lang3.StringUtils;
import java.lang.reflect.Field;
public class EnDecryptPojoUtils {

    /** Marker used by the placeholder "cipher" below (TODO: replace with real encryption). */
    private static final String PLACEHOLDER_SUFFIX = "---";

    /**
     * Encrypts, in place, the non-empty {@code String} fields of {@code t} annotated with {@link EncryptField}.
     * Only fields declared directly on the runtime class of {@code t} are considered; inherited fields are not.
     * Nothing happens unless that class is annotated with {@link EnDecryptMapperEntity}.
     *
     * @param t the object to encrypt; {@code null} is ignored
     * @throws RuntimeException wrapping an {@link IllegalAccessException} if a field cannot be accessed
     */
    public static <T> void encrypt(T t) {
        if (!isEncryptAndDecrypt(t)) {
            return;
        }
        try {
            for (Field field : t.getClass().getDeclaredFields()) {
                // Exact type check: a name-suffix check would also match custom "...String" types and then
                // fail the (String) cast below.
                if (field.getType() == String.class && field.isAnnotationPresent(EncryptField.class)) {
                    field.setAccessible(true);
                    String fieldValue = (String) field.get(t);
                    if (StringUtils.isNotEmpty(fieldValue)) {
                        field.set(t, encryptValue(fieldValue));
                    }
                }
            }
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Decrypts, in place, the non-empty {@code String} fields of {@code t} annotated with {@link DecryptField}.
     * Note that this is a different annotation from the one {@link #encrypt} looks for.
     * Only fields declared directly on the runtime class of {@code t} are considered.
     * Nothing happens unless that class is annotated with {@link EnDecryptMapperEntity}.
     *
     * @param t the object to decrypt; {@code null} is ignored
     * @throws RuntimeException wrapping an {@link IllegalAccessException} if a field cannot be accessed
     */
    public static <T> void decrypt(T t) {
        if (!isEncryptAndDecrypt(t)) {
            return;
        }
        try {
            for (Field field : t.getClass().getDeclaredFields()) {
                if (field.getType() == String.class && field.isAnnotationPresent(DecryptField.class)) {
                    field.setAccessible(true);
                    String fieldValue = (String) field.get(t);
                    if (StringUtils.isNotEmpty(fieldValue)) {
                        field.set(t, decryptValue(fieldValue));
                    }
                }
            }
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns whether {@code t} is non-null and its class is annotated with {@link EnDecryptMapperEntity}.
     */
    public static <T> Boolean isEncryptAndDecrypt(T t) {
        return t != null && t.getClass().isAnnotationPresent(EnDecryptMapperEntity.class);
    }

    /** Placeholder encryption (TODO: real cipher): appends {@link #PLACEHOLDER_SUFFIX}. */
    private static String encryptValue(String plainText) {
        return plainText + PLACEHOLDER_SUFFIX;
    }

    /**
     * Placeholder decryption, the exact inverse of {@link #encryptValue}.
     * It removes one trailing {@link #PLACEHOLDER_SUFFIX} and leaves any other occurrence inside the
     * value intact. Values without the suffix are returned unchanged.
     */
    private static String decryptValue(String cipherText) {
        return cipherText.endsWith(PLACEHOLDER_SUFFIX)
                ? cipherText.substring(0, cipherText.length() - PLACEHOLDER_SUFFIX.length())
                : cipherText;
    }
}
所用框架为 Spring Boot 2.5.7。
具体代码可参考 https://github.com/sky-91/redis/tree/mybatis-plugin-0323