项目场景描述:
相信好多的小伙伴都遇到过动态切换数据源的任务,这里我用到的是Druid,具体看项目要求,就好比我们这个项目吧,我们需要根据用户登录进来的时候获取他所在的公司地区,然后根据这个公司地区查询特定数据库,实现动态切换。
要求:
- 项目启动时已指定一个默认数据源。
- 系统不固定数据源,需要在访问时通过自定义的方式去指定或获取。
具体实施
1.导入maven依赖(主要的依赖如下:)
<!-- MySQL JDBC driver; runtime scope, so it is not on the compile classpath. -->
<!-- No <version> is given here: presumably managed by a parent POM/BOM (TODO confirm). -->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<scope>runtime</scope>
</dependency>
<!-- tk.mybatis mapper Spring Boot starter, pinned to 2.1.5. -->
<dependency>
<groupId>tk.mybatis</groupId>
<artifactId>mapper-spring-boot-starter</artifactId>
<version>2.1.5</version>
</dependency>
<!-- Spring Boot AOP starter; version presumably inherited from the Spring Boot parent (TODO confirm). -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Alibaba Druid connection pool as a plain library (not the Boot starter), pinned to 1.1.16. -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>1.1.16</version>
</dependency>
2. 数据源上下文DataSourceContextHolder
package net.cnki.common.datasource;
/**
 * Holds, per thread, the lookup key of the data source the current thread should use.
 *
 * <p>The value lives in a {@link ThreadLocal}, so each thread reads and writes only its
 * own copy. Because of that, none of the accessors need additional locking.
 */
public class DataSourceContextHolder {

    /** Per-thread data source lookup key; {@code null} when nothing has been selected. */
    private static final ThreadLocal<String> contextHolder = new ThreadLocal<String>();

    /**
     * Selects the data source for the calling thread.
     *
     * <p>Intentionally not {@code synchronized}: the write only affects the calling
     * thread's own slot, so a class-wide lock would add contention without adding safety.
     *
     * @param dbType lookup key of the data source to use on this thread
     */
    public static void setDBType(String dbType) {
        contextHolder.set(dbType);
    }

    /**
     * Returns the data source key selected for the calling thread.
     *
     * @return the lookup key, or {@code null} if none is set on this thread
     */
    public static String getDBType() {
        return contextHolder.get();
    }

    /** Removes the calling thread's selection so a reused thread does not inherit it. */
    public static void clearDBType() {
        contextHolder.remove();
    }
}
3.DynamicDataSource继承AbstractRoutingDataSource
重写了determineCurrentLookupKey()方法,在多个数据源中确定当前所需要使用的那一个。其实DynamicDataSource本身就是一个线程安全下的单例(单例本想用枚举,但是不可以继承,所以放弃了),dataSourceMap用于存储数据源信息。
package net.cnki.common.datasource;
import java.util.HashMap;
import java.util.Map;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
/**
 * Routing data source that picks its target from the key stored in
 * {@link DataSourceContextHolder} for the current thread.
 *
 * <p>A single shared instance is handed out by {@link #getInstance()}, and every target
 * data source passed to {@link #setTargetDataSources(Map)} is also recorded in a static
 * registry that callers can inspect and extend via {@link #getDataSourceMap()}.
 */
public class DynamicDataSource extends AbstractRoutingDataSource {

    /** Lazily created shared instance; only touched inside the synchronized {@link #getInstance()}. */
    private static DynamicDataSource instance;

    /**
     * Registry of all known target data sources, keyed by lookup key.
     *
     * <p>Callers add to this map from request threads, so it is a concurrent map rather
     * than a plain {@code HashMap}. Note that a concurrent map rejects null keys and values.
     */
    private static final Map<Object, Object> dataSourceMap =
            new java.util.concurrent.ConcurrentHashMap<Object, Object>();

    /**
     * Installs the given targets, records them in the shared registry and re-runs
     * {@code afterPropertiesSet()} so data sources added at runtime become routable.
     *
     * @param targetDataSources lookup key to data source mapping; must not be {@code null}
     */
    @Override
    public void setTargetDataSources(Map<Object, Object> targetDataSources) {
        super.setTargetDataSources(targetDataSources);
        dataSourceMap.putAll(targetDataSources);
        super.afterPropertiesSet();
    }

    /**
     * Returns the live shared registry (not a copy); changes made by the caller are
     * visible to everyone using this class.
     *
     * @return the mutable registry of target data sources
     */
    public Map<Object, Object> getDataSourceMap() {
        return dataSourceMap;
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * <p>The method is {@code synchronized}, which already makes the lazy creation safe;
     * no second lock is needed inside.
     *
     * @return the single shared {@code DynamicDataSource}
     */
    public static synchronized DynamicDataSource getInstance() {
        if (instance == null) {
            instance = new DynamicDataSource();
        }
        return instance;
    }

    /**
     * Supplies the routing key for the current call.
     *
     * @return the key held by {@link DataSourceContextHolder} for this thread, possibly {@code null}
     */
    @Override
    protected Object determineCurrentLookupKey() {
        return DataSourceContextHolder.getDBType();
    }
}
4.Druid配置类(根据需要配置即可;连接池的最大、最小连接数等参数可以配置成通用的,就不需要动态获取了)
package net.cnki.common.datasource;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.Servlet;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.support.http.StatViewServlet;
import com.alibaba.druid.support.http.WebStatFilter;
/**
 * Spring configuration that builds the default Druid pool from {@code spring.datasource.*}
 * properties, registers it with the shared {@link DynamicDataSource}, and exposes the
 * Druid monitoring servlet and web statistics filter.
 */
@Configuration
public class DruidConfig {

    // NOTE(review): injected but not used anywhere in this class.
    @Value("${spring.datasource.type}")
    private String dbType;

    @Value("${spring.datasource.driver-class-name}")
    private String driverClassName;

    @Value("${spring.datasource.url}")
    private String url;

    @Value("${spring.datasource.username}")
    private String username;

    @Value("${spring.datasource.password}")
    private String password;

    // Number of connections created when the pool starts.
    @Value("${spring.datasource.initialSize}")
    private int initialSize;

    // Minimum number of idle connections kept in the pool.
    @Value("${spring.datasource.minIdle}")
    private int minIdle;

    // Maximum number of active connections.
    @Value("${spring.datasource.maxActive}")
    private int maxActive;

    // Maximum time to wait for a connection.
    @Value("${spring.datasource.maxWait}")
    private int maxWait;

    // Interval, in milliseconds, between checks for idle connections to close.
    @Value("${spring.datasource.timeBetweenEvictionRunsMillis}")
    private int timeBetweenEvictionRunsMillis;

    // Minimum time, in milliseconds, a connection stays in the pool before it may be evicted.
    @Value("${spring.datasource.minEvictableIdleTimeMillis}")
    private int minEvictableIdleTimeMillis;

    // Query used to validate a connection; must be a SELECT returning at least one row.
    @Value("${spring.datasource.validationQuery}")
    private String validationQuery;

    // Validate connections while they sit idle.
    @Value("${spring.datasource.testWhileIdle}")
    private boolean testWhileIdle;

    // Validate a connection when it is borrowed; costs performance.
    @Value("${spring.datasource.testOnBorrow}")
    private boolean testOnBorrow;

    // Validate a connection when it is returned; costs performance.
    @Value("${spring.datasource.testOnReturn}")
    private boolean testOnReturn;

    // Whether to cache prepared statements (PSCache).
    @Value("${spring.datasource.poolPreparedStatements}")
    private boolean poolPreparedStatements;

    // PSCache size per connection.
    @Value("${spring.datasource.maxPoolPreparedStatementPerConnectionSize}")
    private int maxPoolPreparedStatementPerConnectionSize;

    // Monitoring/firewall filter names.
    // NOTE(review): injected but never applied to the data source built below; confirm
    // whether the filters are meant to be set on the pool.
    @Value("${spring.datasource.filters}")
    private String filters;

    // Connection properties (mergeSql, slow SQL logging).
    // NOTE(review): injected but never applied to the data source built below.
    @Value("${spring.datasource.connectionProperties}")
    private String connectionProperties;

    // Slow SQL logging switch for the Druid console.
    // NOTE(review): injected but not used anywhere in this class.
    @Value("${spring.datasource.logSlowSql}")
    private String logSlowSql;

    @Value("${spring.datasource.removeAbandoned}")
    private boolean removeAbandoned;

    @Value("${spring.datasource.removeAbandonedTimeout}")
    private int removeAbandonedTimeout;

    @Value("${spring.datasource.logAbandoned}")
    private boolean logAbandoned;

    /**
     * Exposes the shared routing data source, with the configured pool registered both as
     * the default target and under the lookup key {@code "default"}.
     *
     * @return the shared {@link DynamicDataSource}
     */
    @Bean
    public DynamicDataSource druidDataSource() {
        DruidDataSource defaultDataSource = buildDefaultDataSource();

        DynamicDataSource dynamicDataSource = DynamicDataSource.getInstance();
        // The default must be set before setTargetDataSources, because that call
        // re-runs afterPropertiesSet() on the routing data source.
        dynamicDataSource.setDefaultTargetDataSource(defaultDataSource);

        Map<Object, Object> targets = new HashMap<>();
        targets.put("default", defaultDataSource);
        dynamicDataSource.setTargetDataSources(targets);
        return dynamicDataSource;
    }

    /** Builds the default Druid pool from the injected {@code spring.datasource.*} values. */
    private DruidDataSource buildDefaultDataSource() {
        DruidDataSource dataSource = new DruidDataSource();
        dataSource.setDriverClassName(driverClassName);
        dataSource.setUrl(url);
        dataSource.setUsername(username);
        dataSource.setPassword(password);
        dataSource.setInitialSize(initialSize);
        dataSource.setMinIdle(minIdle);
        dataSource.setMaxActive(maxActive);
        dataSource.setMaxWait(maxWait);
        dataSource.setTimeBetweenEvictionRunsMillis(timeBetweenEvictionRunsMillis);
        dataSource.setMinEvictableIdleTimeMillis(minEvictableIdleTimeMillis);
        dataSource.setValidationQuery(validationQuery);
        dataSource.setTestWhileIdle(testWhileIdle);
        dataSource.setTestOnBorrow(testOnBorrow);
        dataSource.setTestOnReturn(testOnReturn);
        dataSource.setPoolPreparedStatements(poolPreparedStatements);
        dataSource.setMaxPoolPreparedStatementPerConnectionSize(maxPoolPreparedStatementPerConnectionSize);
        dataSource.setRemoveAbandoned(removeAbandoned);
        dataSource.setRemoveAbandonedTimeout(removeAbandonedTimeout);
        dataSource.setLogAbandoned(logAbandoned);
        return dataSource;
    }

    /**
     * Registers the Druid monitoring console under {@code /druid/*}.
     *
     * @return the servlet registration for {@link StatViewServlet}
     */
    @Bean
    public ServletRegistrationBean<Servlet> druid() {
        ServletRegistrationBean<Servlet> registration =
                new ServletRegistrationBean<>(new StatViewServlet(), "/druid/*");
        // Allow list, comma separated; when missing or empty every address is allowed.
        registration.addInitParameter("allow", "127.0.0.1");
        // Deny list, comma separated (deny wins over allow when both match).
        //registration.addInitParameter("deny", "192.168.1.110");
        // Console login credentials.
        // NOTE(review): hard-coded admin/admin; consider moving these to configuration.
        registration.addInitParameter("loginUsername", "admin");
        registration.addInitParameter("loginPassword", "admin");
        // Disables the "Reset All" button on the console page.
        registration.addInitParameter("resetEnable", "false");
        return registration;
    }

    /**
     * Registers Druid's web statistics filter for all requests.
     *
     * @return the filter registration for {@link WebStatFilter}
     */
    @Bean
    public FilterRegistrationBean<Filter> filterRegistrationBean() {
        FilterRegistrationBean<Filter> registration = new FilterRegistrationBean<>();
        registration.setFilter(new WebStatFilter());
        // Monitor every request...
        registration.addUrlPatterns("/*");
        // ...except these patterns (static resources and the console itself).
        registration.addInitParameter("exclusions", "*.js,*.gif,*.jpg,*.css,/druid/*");
        return registration;
    }
}
DruidDataSourceUtil实现切换入口
需要进行切换时调用此方法即可
package net.cnki.common.datasource;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.alibaba.druid.pool.DruidDataSource;
/**
 * Entry point for switching the current thread to a tenant-specific data source,
 * creating and registering that data source on first use.
 */
public class DruidDataSourceUtil {

    private static final Logger logger = LoggerFactory.getLogger(DruidDataSourceUtil.class);

    /**
     * Points the current thread at the data source registered under {@code key + "master"},
     * creating a new Druid pool for it when none is registered yet.
     *
     * <p>When {@code key} is {@code null} the thread is left on the {@code "default"}
     * data source instead of being pointed at a non-existent lookup key.
     *
     * <p>Lookup and registration happen under one lock. This prevents concurrent first
     * requests for the same key from each building a pool, and prevents a thread from
     * selecting a key that has been added to the registry but not yet installed on the
     * routing data source.
     *
     * @param key    tenant identifier; the lookup key is {@code key + "master"}
     * @param dbip   database host
     * @param dbname database (schema) name
     * @param dbuser database user
     * @param dbpwd  database password
     */
    public static void addOrChangeDataSource(String key, String dbip, String dbname, String dbuser, String dbpwd) {
        // Start on the default so the thread has a valid selection whatever happens below.
        DataSourceContextHolder.setDBType("default");
        if (key == null) {
            return;
        }

        String lookupKey = key + "master";
        DynamicDataSource router = DynamicDataSource.getInstance();
        synchronized (DruidDataSourceUtil.class) {
            Map<Object, Object> dataSourceMap = router.getDataSourceMap();
            if (!dataSourceMap.containsKey(lookupKey)) {
                logger.info("插入新数据库连接信息为:jdbc:mysql://{}:3306/{}?serverTimezone=Hongkong&characterEncoding=UTF-8&useSSL=false",
                        dbip, dbname);
                dataSourceMap.put(lookupKey, createDataSource(dbip, dbname, dbuser, dbpwd));
                // Re-installs the targets so the new key becomes routable.
                // NOTE(review): confirm that refreshing the routing data source while other
                // threads are obtaining connections is safe for the Spring version in use.
                router.setTargetDataSources(dataSourceMap);
            }
        }
        DataSourceContextHolder.setDBType(lookupKey);
    }

    /** Builds a Druid pool for the given MySQL database with this utility's fixed pool settings. */
    private static DruidDataSource createDataSource(String dbip, String dbname, String dbuser, String dbpwd) {
        DruidDataSource dataSource = new DruidDataSource();
        dataSource.setDriverClassName("com.mysql.cj.jdbc.Driver");
        dataSource.setUsername(dbuser);
        dataSource.setUrl("jdbc:mysql://" + dbip + ":3306/" + dbname
                + "?serverTimezone=Hongkong&characterEncoding=UTF-8&useSSL=false&nullCatalogMeansCurrent=true&allowMultiQueries=true");
        dataSource.setPassword(dbpwd);
        dataSource.setInitialSize(50);
        dataSource.setMinIdle(5);
        dataSource.setMaxActive(1000);
        dataSource.setMaxWait(5000);
        dataSource.setTimeBetweenEvictionRunsMillis(60000);
        dataSource.setMinEvictableIdleTimeMillis(300000);
        dataSource.setValidationQuery("SELECT 1 FROM DUAL");
        dataSource.setTestWhileIdle(true);
        dataSource.setTestOnBorrow(false);
        dataSource.setTestOnReturn(false);
        dataSource.setPoolPreparedStatements(true);
        dataSource.setMaxPoolPreparedStatementPerConnectionSize(20);
        dataSource.setRemoveAbandoned(true);
        dataSource.setRemoveAbandonedTimeout(180);
        dataSource.setLogAbandoned(true);
        return dataSource;
    }
}
切换
// Registers a data source under key+"master" if none exists yet, then selects it for the current thread.
DruidDataSourceUtil.addOrChangeDataSource(key,dbip,dbname,dbuser,dbpwd);
3.1 全局切换
全局切换也就是每次访问都会切换数据源,不需要考虑到底是哪些接口。
继承OncePerRequestFilter。
注意:这里都是一些业务逻辑校验jwt,过滤url等,可以直接忽略,如切换数据源的数量不是很大,可以直接根据用户传过来的信息,进行一一匹配然后把数据源写死即可。
package net.cnki.security.filter;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.stereotype.Component;
import org.springframework.util.PathMatcher;
import org.springframework.web.filter.OncePerRequestFilter;
import com.alibaba.fastjson.JSONObject;
import net.cnki.api.cnki.bean.CasDbBean;
import net.cnki.common.returned.ResponseUtil;
import net.cnki.common.returned.ResultCode;
import net.cnki.common.returned.ResultGenerator;
import net.cnki.common.datasource.DruidDataSourceUtil;
import net.cnki.common.redis.JedisUtils;
import net.cnki.common.redis.RedisConstants;
import net.cnki.security.jwt.JwtTokenUtil;
import net.cnki.util.AESUtil;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;
@Component
public class JwtAuthenticationTokenFilter extends OncePerRequestFilter {
Logger logger = LoggerFactory.getLogger(this.getClass());
@Autowired
private UserDetailsService userDetailsService;
@Autowired
ResultGenerator resultGenerator;
@Autowired
private PathMatcher pathMatcher;
@Autowired
private JwtTokenUtil jwtTokenUtil;
@Autowired
JedisUtils jedisUtils;
@Value("${jwt.token.header}")
private String token_header;
@Value("${jwt.token.type}")
private String token_type;
@Value("${jwt.token.passUrl}")
private List<String> passUrl;
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
String requestUrl = request.getRequestURI();
logger.info("["+requestUrl+"]访问校验jwt,并将用户角色信息写入内存!");
//判断URL是否需要验证
Boolean flag = true;
for(String url : passUrl){
if(pathMatcher.match(url, requestUrl)){
flag = false;
break;
}
}
//根据判断结果执行校验
if (flag) {
String authHeader = request.getHeader(this.token_header);
if (authHeader != null && authHeader.startsWith(this.token_type)) {
//获取token
String authToken = authHeader.substring(this.token_type.length());
if (!jwtTokenUtil.isTokenExpired(authToken)) {//无效token去更新
//根据token获取用户名
String username = jwtTokenUtil.getUserNameFromToken(authToken);
if (username != null) {
String retoken = jedisUtils.get(username, RedisConstants.datebase1);
if (StringUtils.isEmpty(retoken)) {
logger.error("用户:"+username+" 访问url:["+requestUrl+"]校验失败,未登录!");
ResponseUtil.out(response, 402, resultGenerator.getFreeResult(ResultCode.LOGIN_NO).toString());
return;
}
//获取用户对应数据源
String dbStr = jedisUtils.get(username,RedisConstants.datebase2);
if (dbStr != null) {
String dbinfo = AESUtil.decryptPwd(dbStr);
CasDbBean casDbBean = JSONObject.parseObject(dbinfo, CasDbBean.class);
DruidDataSourceUtil.addOrChangeDataSource(casDbBean.getSchoolId(),casDbBean.getDbIp(),casDbBean.getDbName(),casDbBean.getDbUser(),casDbBean.getDbPassword());
}
UserDetails userDetails = this.userDetailsService.loadUserByUsername(username);
if (jwtTokenUtil.validateToken(authToken, userDetails) && !StringUtils.isEmpty(retoken)) {
//验证token是否有效
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities());
authentication.setDetails(new WebAuthenticationDetailsSource().buildDetails(request));
SecurityContextHolder.getContext().setAuthentication(authentication);
chain.doFilter(request, response);
return;
}
}
}
}
}else {//无需校验直接通过
chain.doFilter(request, response);
return;
}
logger.error("访问url:["+requestUrl+"]校验失败,无权访问!");
ResponseUtil.out(response, 403, resultGenerator.getFreeResult(ResultCode.NO_PERMISSION).toString());
}
}