在有些项目中需要使用多个数据源,甚至是不同类型的数据库,但又希望共用同一套MyBatis的Mapper接口以及XML资源。此时可以根据数据源动态创建新的SqlSessionFactory实例,而不是只使用启动过程中创建的单例。对应的代码如下,主要有两点:
- 大体逻辑直接从`org.mybatis.spring.boot.autoconfigure.MybatisAutoConfiguration#sqlSessionFactory`中拷贝而来(这样可以共用配置文件中针对MyBatis的大部分配置;注意由于每次都会新建Configuration对象,`mybatis.configuration.*`下的配置不会被带入新的Configuration);
- 需要修改`org.mybatis.spring.boot.autoconfigure.MybatisAutoConfiguration#applyConfiguration`方法。因为默认情况下只有一个Configuration对象,而一旦涉及数据源切换(尤其是不同类型的数据库,会涉及databaseId的问题),就必须使用不同的Configuration对象。
import org.apache.ibatis.mapping.DatabaseIdProvider;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.scripting.LanguageDriver;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.type.TypeHandler;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.mybatis.spring.boot.autoconfigure.ConfigurationCustomizer;
import org.mybatis.spring.boot.autoconfigure.MybatisProperties;
import org.mybatis.spring.boot.autoconfigure.SpringBootVFS;
import org.springframework.beans.BeanWrapperImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ResourceLoader;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import javax.sql.DataSource;
import java.beans.PropertyDescriptor;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Creates MyBatis {@link SqlSessionFactory} instances for arbitrary {@link DataSource}s at runtime, in
 * addition to the singleton factory created at startup.
 *
 * <p>The building logic is copied from
 * {@code org.mybatis.spring.boot.autoconfigure.MybatisAutoConfiguration#sqlSessionFactory}. The
 * {@code mybatis.*} properties, interceptors, type handlers, language drivers, databaseId provider and
 * {@link ConfigurationCustomizer} beans are therefore applied to every new factory.
 *
 * <p>The one deliberate difference is in {@code applyConfiguration}: each factory gets its own
 * {@link Configuration}, because data sources of different database types need different databaseIds.
 * As a consequence, settings bound from {@code mybatis.configuration.*} are not carried over.
 *
 * @author: guanglai.zhou
 * @date: 2021/11/11 17:10
 */
@Component
public class SqlSessionFactoryProvider {
    @Autowired(required = false)
    private MybatisProperties properties;
    @Autowired(required = false)
    private Interceptor[] interceptors;
    @Autowired(required = false)
    private TypeHandler[] typeHandlers;
    @Autowired(required = false)
    private LanguageDriver[] languageDrivers;
    @Autowired(required = false)
    private ResourceLoader resourceLoader;
    @Autowired(required = false)
    private DatabaseIdProvider databaseIdProvider;
    @Autowired(required = false)
    private List<ConfigurationCustomizer> configurationCustomizers;

    /**
     * Creates a new SqlSession factory that works on the given data source.
     *
     * @param dataSource the data source the new factory should use
     * @return a new MyBatis SqlSession factory
     * @throws RuntimeException if the factory cannot be built; checked exceptions are wrapped, unchecked
     *                          ones are rethrown as they are
     */
    public SqlSessionFactory newInstance(DataSource dataSource) {
        try {
            return sqlSessionFactory(dataSource);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds a SqlSessionFactory for {@code dataSource} the same way the Spring Boot auto-configuration
     * builds its own, except that a fresh {@link Configuration} is used.
     *
     * @param dataSource the data source the new factory should use
     * @return the built factory
     * @throws Exception whatever {@link SqlSessionFactoryBean#getObject()} throws while building
     */
    public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
        MybatisProperties props = mybatisProperties();
        SqlSessionFactoryBean factory = new SqlSessionFactoryBean();
        factory.setDataSource(dataSource);
        factory.setVfs(SpringBootVFS.class);
        if (StringUtils.hasText(props.getConfigLocation())) {
            factory.setConfigLocation(this.resourceLoader.getResource(props.getConfigLocation()));
        }
        applyConfiguration(factory, props);
        if (props.getConfigurationProperties() != null) {
            factory.setConfigurationProperties(props.getConfigurationProperties());
        }
        if (!ObjectUtils.isEmpty(this.interceptors)) {
            factory.setPlugins(this.interceptors);
        }
        if (this.databaseIdProvider != null) {
            factory.setDatabaseIdProvider(this.databaseIdProvider);
        }
        if (StringUtils.hasLength(props.getTypeAliasesPackage())) {
            factory.setTypeAliasesPackage(props.getTypeAliasesPackage());
        }
        if (props.getTypeAliasesSuperType() != null) {
            factory.setTypeAliasesSuperType(props.getTypeAliasesSuperType());
        }
        if (StringUtils.hasLength(props.getTypeHandlersPackage())) {
            factory.setTypeHandlersPackage(props.getTypeHandlersPackage());
        }
        if (!ObjectUtils.isEmpty(this.typeHandlers)) {
            factory.setTypeHandlers(this.typeHandlers);
        }
        if (!ObjectUtils.isEmpty(props.resolveMapperLocations())) {
            factory.setMapperLocations(props.resolveMapperLocations());
        }
        // The scripting-language properties only exist on newer SqlSessionFactoryBean versions,
        // so probe for them instead of calling the setters unconditionally.
        Set<String> factoryPropertyNames = Stream
                .of(new BeanWrapperImpl(SqlSessionFactoryBean.class).getPropertyDescriptors())
                .map(PropertyDescriptor::getName)
                .collect(Collectors.toSet());
        Class<? extends LanguageDriver> defaultLanguageDriver = props.getDefaultScriptingLanguageDriver();
        if (factoryPropertyNames.contains("scriptingLanguageDrivers") && !ObjectUtils.isEmpty(this.languageDrivers)) {
            // Need to mybatis-spring 2.0.2+
            factory.setScriptingLanguageDrivers(this.languageDrivers);
            if (defaultLanguageDriver == null && this.languageDrivers.length == 1) {
                defaultLanguageDriver = this.languageDrivers[0].getClass();
            }
        }
        if (factoryPropertyNames.contains("defaultScriptingLanguageDriver")) {
            // Need to mybatis-spring 2.0.2+
            factory.setDefaultScriptingLanguageDriver(defaultLanguageDriver);
        }
        return factory.getObject();
    }

    /**
     * Returns the {@code mybatis.*} properties bean, or default properties when no such bean exists
     * (the field is injected with {@code required = false}).
     */
    private MybatisProperties mybatisProperties() {
        return this.properties != null ? this.properties : new MybatisProperties();
    }

    /**
     * Gives the factory its own, fresh {@link Configuration} and runs every {@link ConfigurationCustomizer}
     * bean against it.
     *
     * <p>The shared {@code props.getConfiguration()} instance is deliberately not reused, because one
     * Configuration cannot serve data sources that need different databaseIds. Settings bound from
     * {@code mybatis.configuration.*} are therefore not applied here; use a ConfigurationCustomizer if
     * they are needed.
     */
    private void applyConfiguration(SqlSessionFactoryBean factory, MybatisProperties props) {
        // SqlSessionFactoryBean refuses 'configuration' and 'configLocation' together; with a config
        // location MyBatis builds a new Configuration from the XML file for every factory anyway.
        if (StringUtils.hasText(props.getConfigLocation())) {
            return;
        }
        Configuration configuration = new Configuration();
        if (this.configurationCustomizers != null) {
            for (ConfigurationCustomizer customizer : this.configurationCustomizers) {
                customizer.customize(configuration);
            }
        }
        factory.setConfiguration(configuration);
    }
}
传入数据源(此时已经完全动态化,比如由前端传递数据源信息)并获取到SqlSessionFactory实例之后,就可以通过以下方式获取MyBatis的Mapper接口实例来操作数据库了(注意用完后需要关闭SqlSession):
ColumnDefineMapper usedColumnDefineMapper = usedSqlSessionFactory.openSession().getMapper(ColumnDefineMapper.class);
但是这样操作一方面与MyBatis的API耦合,另一方面在业务层每次都需要手动获取Mapper。所以考虑在Spring注入ColumnDefineMapper的时候注入一个代理对象:如果当前线程上下文中包含相关信息(比如DataSource或者SqlSessionFactory),就使用对应的数据源;否则仍然使用默认的Mapper和数据源。
import org.apache.ibatis.session.SqlSessionFactory;
import javax.sql.DataSource;
/**
 * Per-thread routing context for MyBatis calls: an optional {@link SqlSessionFactory} and an optional
 * {@link DataSource} bound to the current thread.
 *
 * <p>Values are stored in {@link ThreadLocal}s and stay bound until they are cleared, so callers should
 * clear them (for example with {@link #clearAll()} in a {@code finally} block) when they are done.
 *
 * @author: guanglai.zhou
 * @date: 2021/11/11 18:52
 */
public class MybatisSourceHolder {

    /** SqlSessionFactory bound to the current thread, if any. */
    private static final ThreadLocal<SqlSessionFactory> FACTORY_HOLDER = new ThreadLocal<>();

    /** DataSource bound to the current thread, if any. */
    private static final ThreadLocal<DataSource> DATA_SOURCE_HOLDER = new ThreadLocal<>();

    // ---------------------------------------------------------------- SqlSessionFactory

    /** Binds the given factory to the current thread. */
    public static void setSqlSessionFactoryThreadLocal(SqlSessionFactory sqlSessionFactory) {
        FACTORY_HOLDER.set(sqlSessionFactory);
    }

    /** Returns the factory bound to the current thread, or {@code null} if none is bound. */
    public static SqlSessionFactory getSqlSessionFactoryThreadLocal() {
        return FACTORY_HOLDER.get();
    }

    /** Removes the factory bound to the current thread. */
    public static void clearSqlSessionFactoryThreadLocal() {
        FACTORY_HOLDER.remove();
    }

    // ---------------------------------------------------------------- DataSource

    /** Binds the given data source to the current thread. */
    public static void setDataSourceThreadLocal(DataSource dataSource) {
        DATA_SOURCE_HOLDER.set(dataSource);
    }

    /** Returns the data source bound to the current thread, or {@code null} if none is bound. */
    public static DataSource getDataSourceThreadLocal() {
        return DATA_SOURCE_HOLDER.get();
    }

    /** Removes the data source bound to the current thread. */
    public static void clearDataSourceThreadLocal() {
        DATA_SOURCE_HOLDER.remove();
    }

    // ---------------------------------------------------------------- both

    /** Removes both the factory and the data source bound to the current thread. */
    public static void clearAll() {
        clearSqlSessionFactoryThreadLocal();
        clearDataSourceThreadLocal();
    }
}
在Bean创建过程中,将MapperFactoryBean及其生成的Mapper转换为代理对象:
import org.aopalliance.intercept.MethodInterceptor;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.mapper.MapperFactoryBean;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.Ordered;
import org.springframework.stereotype.Component;
import org.springframework.util.ClassUtils;
import javax.sql.DataSource;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.*;
import java.util.stream.Collectors;
/**
* @author: guanglai.zhou
* @date: 2021/11/11 19:08
*/
@Component
public class MapperFactoryPostProcessor implements BeanDefinitionRegistryPostProcessor, Ordered, BeanPostProcessor, SmartInitializingSingleton, ApplicationContextAware {
public static final String GET_OBJECT = "getObject";
private Set<String> mapperBeanNames = new HashSet<>();
private SqlSessionFactoryProvider sqlSessionFactoryProvider;
public static final String FACTORY_BEAN_OBJECT_TYPE = "factoryBeanObjectType";
private static Map<String, String> beanNameMapperInterfaceMapping = new HashMap<>();
@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
}
@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
String[] mapperFactoryBeanNames = beanFactory.getBeanNamesForType(MapperFactoryBean.class, false, false);
if (ArrayUtils.isEmpty(mapperFactoryBeanNames)) {
return;
}
for (String mapperFactoryBeanName : mapperFactoryBeanNames) {
String transformedBeanName = BeanFactoryUtils.transformedBeanName(mapperFactoryBeanName);
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(transformedBeanName);
// MapperFactoryBean
String factoryBeanObjectType = (String) beanDefinition.getAttribute(FACTORY_BEAN_OBJECT_TYPE);
beanNameMapperInterfaceMapping.put(transformedBeanName, factoryBeanObjectType);
}
List<String> nameList = Arrays.stream(mapperFactoryBeanNames).collect(Collectors.toList());
mapperBeanNames.addAll(nameList);
for (String name : nameList) {
mapperBeanNames.add(BeanFactoryUtils.transformedBeanName(name));
}
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (mapperBeanNames.contains(beanName)) {
ProxyFactory proxyFactory = new ProxyFactory();
proxyFactory.setProxyTargetClass(true);
proxyFactory.setTargetClass(bean.getClass());
proxyFactory.setTarget(bean);
proxyFactory.addAdvice((MethodInterceptor) invocation -> {
Object[] arguments = invocation.getArguments();
Method method = invocation.getMethod();
if (GET_OBJECT.equals(method.getName()) && method.getParameterCount() == 0) {
Object object = invocation.proceed();
return Proxy.newProxyInstance(ClassUtils.getDefaultClassLoader(), new Class[]{getMapperInterface(beanName)}, (proxy, innerMethod, args) -> {
DataSource dataSourceThreadLocal = MybatisSourceHolder.getDataSourceThreadLocal();
if (dataSourceThreadLocal != null) {
SqlSessionFactory sqlSessionFactory = sqlSessionFactoryProvider.newInstance(dataSourceThreadLocal);
MybatisSourceHolder.setSqlSessionFactoryThreadLocal(sqlSessionFactory);
}
SqlSessionFactory currSqlSessionFactory = MybatisSourceHolder.getSqlSessionFactoryThreadLocal();
if (currSqlSessionFactory != null) {
SqlSession sqlSession = currSqlSessionFactory.openSession();
Object mapper = sqlSession.getMapper(getMapperInterface(beanName));
try {
return innerMethod.invoke(mapper, args);
}finally {
sqlSession.close();
}
});
}
return invocation.proceed();
});
return proxyFactory.getProxy();
}
return bean;
}
private Class getMapperInterface(String beanName) {
String className = beanNameMapperInterfaceMapping.get(beanName);
try {
return ClassUtils.forName(className, ClassUtils.getDefaultClassLoader());
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
@Override
public int getOrder() {
return -1;
}
private ApplicationContext applicationContext;
@Override
public void afterSingletonsInstantiated() {
sqlSessionFactoryProvider = applicationContext.getBean(SqlSessionFactoryProvider.class);
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
}
}
然后在调用相关服务接口之前,往当前线程上下文中添加数据源即可。调用结束后需要及时清理(例如在finally中调用`MybatisSourceHolder.clearAll()`),否则在线程被复用(如线程池)时,数据源会被带到后续的请求中。
以上源码涉及的SpringBoot等依赖版本如下
<spring-framework.version>5.3.8</spring-framework.version>
<spring-boot.version>2.5.2</spring-boot.version>
<spring-security.version>5.4.9</spring-security.version>
<mybatis.version>3.5.7</mybatis.version>
<mybatis-spring.version>2.0.6</mybatis-spring.version>
<mybatis-spring-boot.version>2.2.0</mybatis-spring-boot.version>
比如以上代码中用到的Bean定义属性factoryBeanObjectType,需要Spring 5.2及以上版本才支持。