在开发过程中经常会使用到多数据,比如一个框架库多个业务库、读写分离等。下面我在 Springboot、Druid、mybatis 的基础上,进行动态切换数据库封装(代码可能会存在性能等方面的问题,希望大家多多指出)。
思路步骤:
- 在 springboot 配置文件中,进行多数据源数据库配置。
- 根据配置文件进行数据源初始化,获取对应的 SQLSessionFactroy。
- 定义数据源标识注解,用于标识该 mapper 属于哪个数据源。
- 定义 SessionFactroy,用于获取 mapper 对象
- 创建 DBConnection 用于管理数据库连接、事务提交。
- 测试
下面代码采用 mysql 数据库进行演示,小伙伴们可以根据自己的需要进行相关文件修改,有不当之处,欢迎评论指教。
1.引入相关依赖
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter</artifactId>
<version>2.1.3</version>
</dependency>
<!--引入druid数据源 -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>1.1.8</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<scope>runtime</scope>
</dependency>
2.yml数据库配置
# 数据库配置
spring:
# 数据源名称,多个以逗号分隔
datasourceNames: master,cluster
# 设置默认数据源(为空,默认为第一个数据源)
datasourceDefault: cluster
datasourceItems:
# 框架库(不可修改数据源名称)
master:
url: jdbc:mysql://ip:端口/数据库名1?useUnicode=true&characterEncoding=utf-8&useSSL=false
username: test1
password: test1
driverClassName: com.mysql.cj.jdbc.Driver
# 初始化创建的连接数
initialSize: 10
# 最小空闲连接数
minIdle: 10
# 最大连接数量,连接数连不能超过该值
maxActive: 2000
# 超时等待时间(毫秒),当连接超过该时间便认为其实空闲连接
maxWait: -1
# 申请连接的时候检测,如果空闲时间大于timeBetweenEvictionRunsMillis,执行validationQuery检测连接是否有效。
testWhileIdle: true
timeBetweenEvictionRunsMillis: 60000
# 物理连接初始化的时候执行的sql
connectionInitSqls: select now()
cluster:
url: jdbc:mysql://ip:端口/数据库名2?useUnicode=true&characterEncoding=utf-8&useSSL=false
username: test2
password: test2
driverClassName: com.mysql.cj.jdbc.Driver
# 初始化创建的连接数
initialSize: 10
# 最小空闲连接数
minIdle: 10
# 最大连接数量,连接数连不能超过该值
maxActive: 2000
# 超时等待时间(毫秒),当连接超过该时间便认为其实空闲连接
maxWait: -1
# 申请连接的时候检测,如果空闲时间大于timeBetweenEvictionRunsMillis,执行validationQuery检测连接是否有效。
testWhileIdle: true
timeBetweenEvictionRunsMillis: 60000
# 物理连接初始化的时候执行的sql
connectionInitSqls: select now()
3.初始化数据源
@Configuration
public class DataSourceConfig {
private static final Logger log = LoggerFactory.getLogger(DataSourceConfig.class);
// 默认数据源名称
public static String DATASOURCE_DEFAULT;
// 存储SqlSessionFactory
public static transient Map<String, SqlSessionFactory> sqlSessionFactoryMap = new HashMap<>();
@Autowired
private Environment environment;
@PostConstruct
public void init() {
log.info("系统初始化数据源开始");
// 获取所有数据源名称
String datasourceNames = environment.getProperty("spring.datasourceNames");
// 获取默认数据源名称
String datasourceDefault = environment.getProperty("spring.datasourceDefault");
if (!StringUtils.isBlank(datasourceNames)) {
List<String> dbNames = Splitter.on(",").trimResults().omitEmptyStrings().splitToList(datasourceNames);
try {
if (StringUtils.isBlank(datasourceDefault)) {
DataSourceConfig.DATASOURCE_DEFAULT = dbNames.get(0);
} else if (!dbNames.contains(datasourceDefault)) {
throw new Exception("设置默认数据源" + datasourceDefault + "不存在");
} else {
DataSourceConfig.DATASOURCE_DEFAULT = datasourceDefault;
}
String prefix;
for (String dbName : dbNames) {// 注册DataSource
log.info("初始化数据源:{}", dbName);
prefix = "spring.datasourceItems." + dbName;
DruidDataSource dataSource = new DruidDataSource();
// 是否自动提交
dataSource.setDefaultAutoCommit(false);
dataSource.setUrl(environment.getProperty(prefix + ".url"));
dataSource.setUsername(environment.getProperty(prefix + ".username"));
dataSource.setPassword(environment.getProperty(prefix + ".password"));
dataSource.setDriverClassName(environment.getProperty(prefix + ".driverClassName"));
Optional.ofNullable(environment.getProperty(prefix + ".initialSize")).ifPresent(data -> dataSource.setInitialSize(Integer.valueOf(data)));
Optional.ofNullable(environment.getProperty(prefix + ".minIdle")).ifPresent(data -> dataSource.setMinIdle(Integer.valueOf(data)));
Optional.ofNullable(environment.getProperty(prefix + ".maxActive")).ifPresent(data -> dataSource.setMaxActive(Integer.valueOf(data)));
Optional.ofNullable(environment.getProperty(prefix + ".maxWait")).ifPresent(data -> dataSource.setMaxWait(Integer.valueOf(data)));
Optional.ofNullable(environment.getProperty(prefix + ".testWhileIdle")).ifPresent(data -> dataSource.setTestWhileIdle(Boolean.parseBoolean(data)));
Optional.ofNullable(environment.getProperty(prefix + ".timeBetweenEvictionRunsMillis")).ifPresent(data -> dataSource.setTimeBetweenEvictionRunsMillis(Long.valueOf(data)));
dataSource.setValidationQuery(environment.getProperty(prefix + ".connectionInitSqls"));
// 初始化
dataSource.init();
// 存储sqlSessionFactor
SqlSessionFactoryBean sessionFactory = new SqlSessionFactoryBean();
sessionFactory.setDataSource(dataSource);
sessionFactory.setMapperLocations(new PathMatchingResourcePatternResolver().getResources("classpath*:/com/lzq/tortoise/**/*.xml"));
SqlSessionFactory sqlSessionFactory = sessionFactory.getObject();
sqlSessionFactoryMap.put(dbName, sqlSessionFactory);
log.info("数据源{}初始化成功", dbName);
}
} catch (Exception e) {
log.error("系统初始化数据源失败", e);
e.printStackTrace();
}
}
log.info("系统初始化数据源结束");
}
**4.定义标识注解EnvDataSource **
/**
* 数据源标注类
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface EnvDataSource {
/**
* 数据库名称
* @return
*/
String value();
}
5.编写 SessionFactory
public final class SessionFactory {
/**
* 用来存储数据源对应的SQLSession
*/
private Map<String, SqlSession> evnSqlSessionMap = new HashMap<>();
public <T> T getMapper(Class<T> tClass) throws AppException {
String dbName;
// 获取当前mapper EnvDataSource注解
EnvDataSource envDataSource = tClass.getAnnotation(EnvDataSource.class);
if (envDataSource != null && !StringUtils.isBlank(envDataSource.value())) {
dbName = envDataSource.value();
} else {
dbName = DataSourceConfig.DATASOURCE_DEFAULT;
}
SqlSession session = evnSqlSessionMap.get(dbName);
if (session == null) {
session = getSqlSession(dbName);
evnSqlSessionMap.put(dbName, session);
}
return session.getMapper(tClass);
}
/**
* 提交所有连接,并关闭所有连接。
*/
public void close() {
close(true);
}
/**
* 关闭所有session
*
* @param commit 是否执行提交命令。为false执行rollback。
*/
public void close(boolean commit) {
Collection<SqlSession> sqlSessions = evnSqlSessionMap.values();
if (sqlSessions != null && sqlSessions.size() > 0) {
sqlSessions.stream().forEach(sqlSession -> {
if (commit) {
sqlSession.commit();
} else {
sqlSession.rollback();
}
sqlSession.close();
});
}
evnSqlSessionMap.clear();
}
/**
* 根据数据源名称获取SQLSession
*
* @param dbName
* @return
*/
private SqlSession getSqlSession(String dbName) throws AppException {
SqlSessionFactory sqlSessionFactory = DataSourceConfig.sqlSessionFactoryMap.get(dbName);
if (sqlSessionFactory == null) {
throw new AppException("没有获取到数据源" + dbName);
}
return sqlSessionFactory.openSession();
}
}
6.编写DBConnection
/**
* 定义数据库执行任务接口
*/
public interface DBTask {
Object run(SessionFactory factory) throws AppException;
}
/**
* 数据库连接
*/
public final class DBConnection {
private DBConnection() {
}
public static final <T> T runTask(DBTask task, Class<T> tClass) throws AppException {
if (tClass == null)
throw new NullPointerException(tClass.getName() + " Class can't be null.");
T t;
SessionFactory session = new SessionFactory();
try {
Object o = task.run(session);
if (o == null){
t = null;
}else if (tClass.isInstance(o)) {
t = (T) o;
} else {
throw new IllegalArgumentException("返回的参数异常");
}
session.close();
} catch (Exception ex) {
session.close(false);
throw new AppException(ex);
}
return t;
}
}
7.测试
@EnvDataSource("master")
public interface TestJDBC1Mapper {
// @Select("SELECT * FROM seq")
List<Map> get();
Integer insert(Map map);
}
@EnvDataSource("cluster")
public interface TestJDBC2Mapper {
// @Select("SELECT * FROM aa10")
List<Map> get();
Integer insert(Map map);
}
@Override
public Object threadDB(String index) throws AppException {
return DBConnection.runTask(factory -> {
Map map = new HashMap();
map.put("id", "test00" + index);
map.put("username", "test00" + index);
map.put("loginid", "test00" + index);
map.put("password", "test00" + index);
map.put("phone", "test00" + index);
Map returnMap = new HashMap();
// 新增
returnMap.put("testJDBC1Mapper", factory.getMapper(TestJDBC1Mapper.class).insert(map));
if (Integer.parseInt(index) % 10 == 0) {
int i = 1 / 0;
}
// 新增
returnMap.put("testJDBC2Mapper", factory.getMapper(TestJDBC2Mapper.class).insert(map));
return returnMap;
}, Object.class);
}