为提高数据存储的安全性、避免MySQL数据库被入侵后敏感数据被拖库泄露,本文通过动态代理MyBatis Mapper类,使用对称加密算法AES对敏感数据进行加解密。
(也可以选用 MyBatis 的 `BaseTypeHandler` 实现。)
实现思路
- Spring自动注入配置
- 实现 `BeanPostProcessor` 接口,初始化Bean时为 `@EncryptMapper` 注解修饰的Mapper添加动态代理
- 动态代理类 `EncryptMapperProxy` 处理 `@Encrypt` 注解修饰的参数和字段:写入时加密,读取时解密
- 使用AES算法对数据进行对称加密,密文再经Base64编码后存储
- 使用Base64编码是为了兼容新旧数据切换:读取时只有合法Base64格式的值才会尝试解密。未加密的旧数据原样返回,解密失败时同样原样返回。
实现代码
Bean注册
Spring自动注入配置
/**
 * Auto-configuration that registers {@link EncryptMapperBeanPostProcessor}.
 * <p>
 * Only active when MyBatis-Spring's {@code MapperFactoryBean} is on the classpath, and ordered before
 * MyBatis' own auto-configuration via {@code @AutoConfigureBefore}.
 */
@Configuration
@ConditionalOnClass(MapperFactoryBean.class)
@AutoConfigureBefore(name = "org.mybatis.spring.boot.autoconfigure.MybatisAutoConfiguration")
@Order(Ordered.HIGHEST_PRECEDENCE)
public class MybatisEncryptAutoConfiguration {
    private static final Logger LOGGER = LoggerFactory.getLogger(MybatisEncryptAutoConfiguration.class);

    /**
     * Creates the post-processor that wraps {@code @EncryptMapper} mappers in an encryption proxy.
     * <p>
     * Declared {@code static}: post-processor beans are created very early in the container lifecycle, and a
     * static factory method lets Spring create them without first instantiating this configuration class.
     *
     * @return a new {@link EncryptMapperBeanPostProcessor}
     */
    @Bean
    public static EncryptMapperBeanPostProcessor getEncryptMapperBeanPostProcessor() {
        LOGGER.info("EncryptMapperBeanPostProcessor create");
        return new EncryptMapperBeanPostProcessor();
    }
}
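注册自动配置文件:src/main/resources/META-INF/spring.factories(Spring Boot 2.7+ 也可使用 META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports,该文件每行只写配置类的全限定名)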
org.springframework.boot.autoconfigure.EnableAutoConfiguration=com.example.mybatis.config.MybatisEncryptAutoConfiguration
添加动态代理
实现 `BeanPostProcessor` 接口,初始化Bean时为 `@EncryptMapper` 注解修饰的Mapper添加动态代理:
/**
 * Wraps MyBatis mappers whose interface (or any of its super-interfaces) carries {@link EncryptMapper}
 * in an {@link EncryptMapperProxy}.
 * <p>
 * When a {@link MapperFactoryBean} is post-processed, its mapper interface is recorded if it is an encrypt
 * mapper. Any later bean whose class is, or directly implements, a recorded interface is replaced by a JDK
 * proxy. (This relies on Spring also passing the mapper object produced by the factory bean through
 * {@link #postProcessAfterInitialization}.)
 */
public class EncryptMapperBeanPostProcessor implements BeanPostProcessor, EmbeddedValueResolverAware {
    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptMapperBeanPostProcessor.class);

    /** Mapper interfaces (annotated with {@code @EncryptMapper}) discovered from MapperFactoryBeans. */
    private final Set<Class<?>> mapperClasses = new HashSet<>();

    /** Resolver handed to each proxy for resolving {@code @Encrypt#secretKey} values; injected by Spring. */
    private StringValueResolver resolver;

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }

    /**
     * Checks whether {@code mapperClass} or any interface it (transitively) extends is annotated with
     * {@link EncryptMapper}.
     *
     * @param mapperClass the mapper interface; {@code null} yields {@code false}
     * @return {@code true} if the annotation is found anywhere in the interface hierarchy
     */
    private static boolean isEncryptMapper(Class<?> mapperClass) {
        if (mapperClass == null) {
            return false;
        }
        if (mapperClass.getAnnotation(EncryptMapper.class) != null) {
            return true;
        }
        // Check every super-interface; returning from the first iteration would ignore all but the first one.
        for (Class<?> superInterface : mapperClass.getInterfaces()) {
            if (isEncryptMapper(superInterface)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if (bean instanceof MapperFactoryBean) {
            Class<?> mapperInterface = ((MapperFactoryBean<?>) bean).getMapperInterface();
            if (isEncryptMapper(mapperInterface)) {
                mapperClasses.add(mapperInterface);
                LOGGER.info("find MapperFactoryBean beanName={} mapperInterface={}", beanName, mapperInterface);
            } else {
                LOGGER.info("ignore MapperFactoryBean beanName={} mapperInterface={}", beanName, mapperInterface);
            }
            return bean;
        }
        Class<?> mapperClass = findRecordedMapperClass(bean.getClass());
        if (mapperClass != null) {
            LOGGER.info("proxy Mapper beanName={} class={} ", beanName, mapperClass);
            return createMapperProxy(bean, mapperClass);
        }
        return bean;
    }

    /**
     * Finds the recorded mapper interface that {@code beanClass} is, or directly implements.
     *
     * @param beanClass runtime class of the bean being post-processed
     * @return the matching recorded mapper interface, or {@code null} if none matches
     */
    private Class<?> findRecordedMapperClass(Class<?> beanClass) {
        for (Class<?> mapperClass : mapperClasses) {
            if (beanClass.equals(mapperClass)) {
                return mapperClass;
            }
            for (Class<?> interfaceClass : beanClass.getInterfaces()) {
                if (interfaceClass.equals(mapperClass)) {
                    return mapperClass;
                }
            }
        }
        return null;
    }

    /** Creates a JDK proxy implementing {@code mapperClass} that delegates to {@code bean} through an {@link EncryptMapperProxy}. */
    private Object createMapperProxy(Object bean, Class<?> mapperClass) {
        return Proxy.newProxyInstance(mapperClass.getClassLoader(), new Class[]{mapperClass},
                new EncryptMapperProxy<>(mapperClass, bean, resolver));
    }

    @Override
    public void setEmbeddedValueResolver(StringValueResolver resolver) {
        this.resolver = resolver;
    }
}
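`BeanPostProcessor` 也称为Bean后置处理器,是Spring定义的接口。Spring容器创建Bean的过程中(具体为Bean初始化前后)会回调其中定义的两个方法:
- `postProcessBeforeInitialization` 在每个Bean的初始化方法调用之前回调;
- `postProcessAfterInitialization` 在每个Bean的初始化方法调用之后回调。

对于 `MapperFactoryBean` 这类 `FactoryBean`,其 `getObject()` 产生的Mapper对象也会经过 `postProcessAfterInitialization`。上面的代码正是先记录 `MapperFactoryBean` 的Mapper接口,再在这一步把Mapper对象替换为代理对象。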
动态代理加解密
动态代理类 `EncryptMapperProxy` 处理 `@Encrypt` 注解修饰的参数和字段:使用AES算法对数据进行对称加密,写入时执行加密操作,读取时执行解密操作。
/**
 * {@link InvocationHandler} that wraps a MyBatis mapper and applies AES encryption to values marked with
 * {@link Encrypt}.
 * <ul>
 *   <li>Before the call: {@code String} parameters annotated with {@code @Encrypt} are replaced by their
 *       ciphertext. {@code @Encrypt} String fields of object parameters are encrypted in place and restored
 *       to their original values after the call, so the caller's objects end up unchanged.</li>
 *   <li>After the call, the following are decrypted:
 *       an {@code @Encrypt}-annotated {@code String} result,
 *       the elements of an {@code @Encrypt}-annotated {@code Collection<String>} result,
 *       and the {@code @Encrypt} fields of a returned object or of each element of a returned collection.</li>
 * </ul>
 * Secret keys are taken from {@code @Encrypt#secretKey()} and passed through the Spring
 * {@link StringValueResolver} before use.
 * <p>
 * Only fields declared directly on a class are scanned ({@code getDeclaredFields}); annotated fields inherited
 * from a superclass are not processed.
 *
 * @param <T> the mapper interface type
 */
public class EncryptMapperProxy<T> implements InvocationHandler, Serializable {
    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptMapperProxy.class);
    /** Cache: class -> its directly declared String fields annotated with {@code @Encrypt}. */
    private static final ConcurrentHashMap<Class<?>, List<Field>> classAndEncryptFields = new ConcurrentHashMap<>();
    /** Cache: mapper method -> indexes of the parameters that may need encryption. */
    private static final ConcurrentHashMap<Method, int[]> methodAndEncryptParameterIndexes = new ConcurrentHashMap<>();
    /** Cache: mapper method -> (parameter index -> {@code @Encrypt} of that String parameter). */
    private static final ConcurrentHashMap<Method, Map<Integer, Encrypt>> methodAndParameterEncrypts = new ConcurrentHashMap<>();

    private final Class<T> mapperClass;
    private final Object mapper;
    private final StringValueResolver resolver;

    /**
     * @param mapperClass the mapper interface being proxied
     * @param mapper      the real mapper that calls are delegated to
     * @param resolver    resolves {@code @Encrypt#secretKey()} values
     */
    public EncryptMapperProxy(Class<T> mapperClass, Object mapper, StringValueResolver resolver) {
        this.mapperClass = mapperClass;
        this.mapper = mapper;
        this.resolver = resolver;
    }

    public Class<T> getMapperClass() {
        return mapperClass;
    }

    /**
     * Returns (and caches) the indexes of the parameters of {@code method} that may need encryption.
     * These are {@code @Encrypt String} parameters plus every other parameter whose type may carry
     * {@code @Encrypt} fields. The {@code @Encrypt} annotations of String parameters are cached in
     * {@link #methodAndParameterEncrypts} as a side effect.
     */
    private int[] getEncryptParameterIndex(Method method) {
        int[] cached = methodAndEncryptParameterIndexes.get(method);
        if (cached != null) {
            return cached;
        }
        Parameter[] parameters = method.getParameters();
        List<Integer> indexes = new ArrayList<>(parameters.length);
        // Built locally and published once, so no thread ever mutates a map another thread can see.
        Map<Integer, Encrypt> parameterEncrypts = new HashMap<>();
        for (int index = 0; index < parameters.length; index++) {
            Parameter parameter = parameters[index];
            Class<?> type = parameter.getType();
            if (isNotSupportType(type)) {
                continue;
            }
            if (String.class.equals(type)) {
                Encrypt encrypt = parameter.getAnnotation(Encrypt.class);
                if (encrypt != null) {
                    parameterEncrypts.put(index, encrypt);
                    indexes.add(index);
                }
            } else {
                indexes.add(index);
            }
        }
        // Publish the annotations before the indexes, so any thread that sees the indexes also sees them.
        if (!parameterEncrypts.isEmpty()) {
            methodAndParameterEncrypts.putIfAbsent(method, Collections.unmodifiableMap(parameterEncrypts));
        }
        int[] result = indexes.stream().mapToInt(Integer::intValue).toArray();
        int[] previous = methodAndEncryptParameterIndexes.putIfAbsent(method, result);
        return previous != null ? previous : result;
    }

    private String genNewThreadName(Method method) {
        return Thread.currentThread().getName() + "-(" + mapperClass.getSimpleName() + "." + method.getName() + "(...)" + ")";
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        String threadName = Thread.currentThread().getName();
        // Tag the thread name with the mapper method for the duration of the call; restored in finally.
        Thread.currentThread().setName(genNewThreadName(method));
        if (LOGGER.isDebugEnabled()) {
            // The proxy is not logged: its toString() would re-enter this handler.
            // The args are not logged either: they may hold plaintext sensitive values.
            LOGGER.debug("invoke mapper={} method={}", mapperClass.getName(), method);
        }
        Map<Object, Map<Field, String>> resetArgFieldValue = Collections.emptyMap();
        try {
            if (args != null && args.length > 0) {
                // Identity-keyed: an entity's equals()/hashCode() may depend on the very fields being encrypted.
                resetArgFieldValue = new java.util.IdentityHashMap<>(args.length);
                encryptArgs(method, args, resetArgFieldValue);
            }
            Object result;
            try {
                result = method.invoke(mapper, args);
            } catch (java.lang.reflect.InvocationTargetException e) {
                // Rethrow the mapper's own exception (e.g. a DataAccessException) instead of letting the
                // reflective wrapper surface as an UndeclaredThrowableException.
                throw e.getTargetException();
            }
            if (method.getReturnType().equals(Void.TYPE) || result == null) {
                return result;
            }
            return encryptResult(method, result);
        } finally {
            try {
                resetArgFieldValue(resetArgFieldValue);
            } finally {
                Thread.currentThread().setName(threadName);
            }
        }
    }

    /**
     * Encrypts the arguments of {@code method} in place.
     * <p>
     * {@code @Encrypt String} arguments are replaced in {@code args}. For object arguments, the plaintext of
     * their {@code @Encrypt} fields is recorded in {@code resetArgFieldValue} so it can be restored afterwards.
     */
    private void encryptArgs(Method method, Object[] args, Map<Object, Map<Field, String>> resetArgFieldValue) {
        // Must run first: it populates methodAndParameterEncrypts for this method.
        int[] encryptParameterIndexes = getEncryptParameterIndex(method);
        Map<Integer, Encrypt> parameterEncrypts = methodAndParameterEncrypts.getOrDefault(method, Collections.emptyMap());
        for (int pIndex : encryptParameterIndexes) {
            Object arg = args[pIndex];
            if (arg == null) {
                continue;
            }
            if (arg instanceof String) {
                // Only String parameters declared with @Encrypt are encrypted. A String passed through e.g. an
                // Object-typed parameter has no annotation and is left as-is.
                Encrypt encrypt = parameterEncrypts.get(pIndex);
                if (encrypt != null) {
                    args[pIndex] = encryption(encrypt, (String) arg);
                }
            } else if (!resetArgFieldValue.containsKey(arg)) {
                // The same object passed twice is encrypted (and later restored) only once.
                Map<Field, String> fieldAndOldValues = handlerFields(arg, false, true);
                if (!fieldAndOldValues.isEmpty()) {
                    resetArgFieldValue.put(arg, fieldAndOldValues);
                }
            }
        }
    }

    /** Writes the recorded original (plaintext) values back into the argument objects' fields. */
    private void resetArgFieldValue(Map<Object, Map<Field, String>> resetArgFieldValue) throws IllegalAccessException {
        for (Map.Entry<Object, Map<Field, String>> argEntry : resetArgFieldValue.entrySet()) {
            Object object = argEntry.getKey();
            for (Map.Entry<Field, String> fieldEntry : argEntry.getValue().entrySet()) {
                Field field = fieldEntry.getKey();
                field.setAccessible(true);
                field.set(object, fieldEntry.getValue());
                // The value itself is not logged: it is the caller's plaintext.
                LOGGER.debug("reset class={} field={}", object.getClass().getName(), field.getName());
            }
        }
    }

    /**
     * Decrypts a non-null mapper result (despite the name, this is the read-side decryption).
     * <ul>
     *   <li>A String return type is decrypted only if the method is annotated with {@code @Encrypt}.</li>
     *   <li>A Collection result is handled by {@link #decipherResultCollectionField}.</li>
     *   <li>Any other result has its {@code @Encrypt} fields decrypted in place.</li>
     * </ul>
     */
    private Object encryptResult(Method method, Object object) {
        Class<?> clazz = method.getReturnType();
        if (clazz.equals(String.class)) {
            Encrypt encrypt = method.getAnnotation(Encrypt.class);
            return encrypt != null ? decryption(encrypt, (String) object) : object;
        }
        if (object instanceof Collection) {
            return decipherResultCollectionField(method, (Collection<?>) object);
        }
        return decipherResultObjectField(method, object);
    }

    /**
     * Decrypts {@code value} if it looks like Base64 ciphertext; otherwise, or if decryption fails, returns it
     * unchanged. This keeps pre-existing plaintext rows readable.
     */
    private String decryption(Encrypt encrypt, String value) {
        // Null (nullable column) and empty values have nothing to decrypt; Base64.isBase64 does not accept null.
        if (value == null || value.isEmpty()) {
            return value;
        }
        if (Base64.isBase64(value)) {
            String secretKey = resolver.resolveStringValue(encrypt.secretKey());
            try {
                return AESCryptos.aesDecrypt(secretKey, value);
            } catch (RuntimeException e) {
                LOGGER.warn("decryption fail value={}", value, e);
            }
        }
        return value;
    }

    /** Runtime-value variant of {@link #isNotSupportType(Class)}. */
    private boolean isNotSupportType(Object object) {
        return isNotSupportType(object.getClass());
    }

    /**
     * Returns {@code true} for types that can never carry {@code @Encrypt} fields handled here:
     * primitives, {@code Boolean}, {@code byte[]}, {@link Blob}, {@link Clob}, {@link Collection} and {@link Map}.
     * Works on declared parameter types as well as on runtime classes.
     */
    private boolean isNotSupportType(Class<?> type) {
        return type.isPrimitive()
                || Boolean.class.equals(type)
                || byte[].class.equals(type)
                || Blob.class.isAssignableFrom(type)
                || Clob.class.isAssignableFrom(type)
                || Collection.class.isAssignableFrom(type)
                || Map.class.isAssignableFrom(type);
    }

    /** Encrypts and Base64-encodes {@code plaintext}; {@code null} and empty strings are returned unchanged. */
    private String encryption(Encrypt encrypt, String plaintext) {
        if (plaintext != null && !plaintext.isEmpty()) {
            String secretKey = resolver.resolveStringValue(encrypt.secretKey());
            return AESCryptos.aesEncryptAndBase64(secretKey, plaintext);
        }
        return plaintext;
    }

    /**
     * Decrypts a collection result.
     * <p>
     * For a {@code Collection<String>} return type annotated with {@code @Encrypt}, a new collection of decrypted
     * strings is returned: an {@code ArrayList} for lists, otherwise an insertion-ordered set. For any other
     * element type, each element's {@code @Encrypt} fields are decrypted in place.
     */
    private Collection<?> decipherResultCollectionField(Method method, Collection<?> collection) {
        if (collection == null || collection.isEmpty()) {
            return collection;
        }
        Class<?> typeArgClass = getReturnCollectionGenericType(method);
        LOGGER.debug("collection genericType:{}", typeArgClass);
        if (String.class.equals(typeArgClass)) {
            Encrypt encrypt = method.getAnnotation(Encrypt.class);
            if (encrypt == null) {
                return collection;
            }
            Collection<String> result;
            if (collection instanceof List) {
                result = new ArrayList<>(collection.size());
            } else {
                result = new java.util.LinkedHashSet<>(collection.size());
            }
            for (Object element : collection) {
                result.add(decryption(encrypt, (String) element));
            }
            return result;
        }
        for (Object element : collection) {
            decipherResultObjectField(method, element);
        }
        return collection;
    }

    /** Returns the first type argument of the method's generic return type when it is a plain class, else {@code null}. */
    private Class<?> getReturnCollectionGenericType(Method method) {
        Type returnType = method.getGenericReturnType();
        if (returnType instanceof ParameterizedType) {
            Type[] typeArguments = ((ParameterizedType) returnType).getActualTypeArguments();
            if (typeArguments[0] instanceof Class) {
                return (Class<?>) typeArguments[0];
            }
        }
        return null;
    }

    /**
     * Decrypts, in place, the {@code @Encrypt} fields of a single returned object.
     * {@code null}, Strings and unsupported types are returned untouched. (MyBatis may put {@code null}
     * elements in a result list.)
     *
     * @param object the returned object
     * @return the same object
     */
    private Object decipherResultObjectField(Method method, Object object) {
        if (object == null || object instanceof String || isNotSupportType(object)) {
            return object;
        }
        handlerFields(object, true);
        return object;
    }

    private Map<Field, String> handlerFields(Object object, boolean isDecryption) {
        return this.handlerFields(object, isDecryption, false);
    }

    /**
     * Encrypts or decrypts, in place, the {@code @Encrypt} String fields of {@code object}.
     *
     * @param object       the object to process; {@code null}, Strings and unsupported types are skipped
     * @param isDecryption {@code true} to decrypt, {@code false} to encrypt
     * @param needOldValue whether to collect the fields' previous values
     * @return the previous value of each processed field when {@code needOldValue} is set, otherwise an empty map
     */
    private Map<Field, String> handlerFields(Object object, boolean isDecryption, boolean needOldValue) {
        if (object == null || isNotSupportType(object) || object instanceof String) {
            return Collections.emptyMap();
        }
        List<Field> encryptFields = extractEncryptFields(object.getClass());
        if (encryptFields.isEmpty()) {
            return Collections.emptyMap();
        }
        Map<Field, String> result = needOldValue ? new HashMap<>(encryptFields.size()) : Collections.emptyMap();
        for (Field field : encryptFields) {
            Encrypt encrypt = field.getAnnotation(Encrypt.class);
            try {
                field.setAccessible(true);
                String oldValue = (String) field.get(object);
                String newValue = isDecryption ? decryption(encrypt, oldValue) : encryption(encrypt, oldValue);
                if (needOldValue) {
                    result.put(field, oldValue);
                }
                field.set(object, newValue);
                // Field values are not logged: one side of the transformation is always plaintext.
                LOGGER.debug("{} value object={} field={}", (isDecryption ? "decryption" : "encryption"),
                        object.getClass().getName(), field.getName());
            } catch (IllegalAccessException e) {
                if (isDecryption) {
                    LOGGER.error("decryption value fail object={} field={}", object.getClass().getName(), field.getName(), e);
                } else {
                    LOGGER.error("encryption value fail object={} field={}", object.getClass().getName(), field.getName(), e);
                }
            }
        }
        return result;
    }

    /**
     * Returns (and caches) the {@code @Encrypt} fields of type String declared directly on {@code clazz}.
     * Annotated non-String fields are skipped with an info log.
     */
    private List<Field> extractEncryptFields(Class<?> clazz) {
        List<Field> cached = classAndEncryptFields.get(clazz);
        if (cached != null) {
            return cached;
        }
        List<Field> encryptFields = new ArrayList<>();
        for (Field field : clazz.getDeclaredFields()) {
            if (field.getAnnotation(Encrypt.class) == null) {
                continue;
            }
            if (field.getType().equals(String.class)) {
                encryptFields.add(field);
                LOGGER.debug("add class={} encrypt field={}", clazz.getName(), field.getName());
            } else {
                LOGGER.info("class={} field={} not String type skip encrypt:", clazz.getName(), field.getName());
            }
        }
        // The cached list is shared by all threads, so it is published read-only.
        List<Field> value = encryptFields.isEmpty()
                ? Collections.<Field>emptyList()
                : Collections.unmodifiableList(encryptFields);
        List<Field> previous = classAndEncryptFields.putIfAbsent(clazz, value);
        return previous != null ? previous : value;
    }
}
AES加密会按16字节分组并填充,密文再经Base64编码(长度约增加1/3)。因此加密后的数据通常是原文长度的数倍,短字符串尤为明显,数据库字段长度不够时需要扩展。
使用方法
Spring Boot 配置 application.yml:
mybatis:
encrypt:
secret-key:
default: xxx #base64后的密钥
student-phone: yyy #base64后的密钥
示例代码
/**
 * Example mapper. {@code @EncryptMapper} turns on the encryption proxy for this interface.
 */
@EncryptMapper
public interface TestMapper {

    /** Each returned name is decrypted, because the method itself is annotated with {@code @Encrypt}. */
    @Encrypt
    List<String> findNames();

    /**
     * Queries with an encrypted {@code name}. The {@code @Encrypt} fields of each returned Student
     * (e.g. {@code name}) are decrypted.
     */
    List<Student> findByName(@Encrypt String name);

    /** The {@code @Encrypt} fields of {@code student} are encrypted before insertion. */
    void insert(Student student);

    /**
     * Inserts an encrypted {@code name}; {@code createTime} is passed through unchanged.
     * NOTE(review): MyBatis looks up mapped statements by method name, so both insert(...) overloads resolve
     * to the same statement id. Confirm this is intended, or give the methods distinct names.
     */
    void insert(@Encrypt String name, Date createTime);
}
/**
 * Example entity. Its {@code @Encrypt} String fields are encrypted when the object is passed to an
 * {@code @EncryptMapper} method, and decrypted when it is returned from one.
 */
public class Student {
    public int id;

    // Marked for encryption with @Encrypt's default secretKey (not shown here; presumably
    // mybatis.encrypt.secret-key.default from application.yml — confirm against the annotation).
    @Encrypt
    public String name;

    // Encrypted with a dedicated key. secretKey goes through Spring's placeholder resolver, so it needs the
    // "${...}" form; without the '$' the literal text "{...}" would be used as the key.
    @Encrypt(secretKey = "${mybatis.encrypt.secret-key.student-phone}")
    public String phone;

    public Date createTime;
}
SQL 查看
字段中均为加密数据时:
SELECT id, result, AES_DECRYPT(FROM_BASE64(result),FROM_BASE64('xxx'))
FROM encrypt_test
字段中同时存在加密和未加密的数据时(仅对Base64格式的值解密):
SELECT id, result,
IF(result REGEXP '^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?$',
AES_DECRYPT(FROM_BASE64(result),FROM_BASE64('xxx')),
result)
FROM encrypt_test
解密函数的用法:
AES_DECRYPT(FROM_BASE64(加密字段),FROM_BASE64('密钥'))