背景:
项目中使用多数据源自动配置功能时,遇到了以下几个问题:防SQL注入需要自定义拦截器;新版本中PaginationInterceptor已过期(@Deprecated);如何缓存Entity和Mapper的对应关系;如何自定义SQL注入器等等。
常见的数据源自动配置代码大概如下:
package com.xxx.conf;
import com.alibaba.druid.pool.DruidDataSource;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.sql.DataSource;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.mybatis.spring.SqlSessionTemplate;
import org.mybatis.spring.annotation.MapperScan;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
/**
 * Spring configuration for the "base" data source.
 *
 * <p>Builds a Druid connection pool from {@code spring.datasource.*} properties and wires the
 * MyBatis {@link SqlSessionFactory}, {@link SqlSessionTemplate} and transaction manager on top of
 * it. Mappers under {@code com.xxx.dao} and {@code com.xxx.mapper} are bound to
 * {@code baseSqlSessionFactory}.
 */
@Configuration
@MapperScan(basePackages = {"com.xxx.dao", "com.xxx.mapper"}, sqlSessionFactoryRef = "baseSqlSessionFactory")
public class DataSourceConfiguration {
    private static final Logger logger = LoggerFactory.getLogger(DataSourceConfiguration.class);

    // Connection settings.
    @Value("${spring.datasource.baseProperties.url}")
    private String dbUrl;
    @Value("${spring.datasource.baseProperties.username}")
    private String username;
    @Value("${spring.datasource.baseProperties.password}")
    private String password;
    @Value("${spring.datasource.driverClassName}")
    private String driverClassName;

    // Pool settings, passed one-to-one to the DruidDataSource setters in baseDatasource().
    @Value("${spring.datasource.initialSize}")
    private int initialSize;
    @Value("${spring.datasource.minIdle}")
    private int minIdle;
    @Value("${spring.datasource.maxActive}")
    private int maxActive;
    @Value("${spring.datasource.maxWait}")
    private int maxWait;
    @Value("${spring.datasource.timeBetweenEvictionRunsMillis}")
    private int timeBetweenEvictionRunsMillis;
    @Value("${spring.datasource.minEvictableIdleTimeMillis}")
    private int minEvictableIdleTimeMillis;
    @Value("${spring.datasource.validationQuery}")
    private String validationQuery;
    @Value("${spring.datasource.testWhileIdle}")
    private boolean testWhileIdle;
    @Value("${spring.datasource.testOnBorrow}")
    private boolean testOnBorrow;
    @Value("${spring.datasource.testOnReturn}")
    private boolean testOnReturn;
    @Value("${spring.datasource.poolPreparedStatements}")
    private boolean poolPreparedStatements;
    @Value("${spring.datasource.maxPoolPreparedStatementPerConnectionSize}")
    private int maxPoolPreparedStatementPerConnectionSize;
    @Value("${spring.datasource.filters}")
    private String filters;
    @Value("${spring.datasource.connectionProperties}")
    private String connectionProperties;

    // MyBatis settings.
    @Value("${mybatis.base.typeAliasesPackage}")
    private String typeAliasesPackage;
    @Value("${mybatis.base.configLocation}")
    private String configLocation;
    /** Comma-separated list of mapper XML location patterns. */
    @Value("${mybatis.base.mapperLocations}")
    private String mapperLocations;

    /**
     * Creates the primary Druid data source from the injected properties.
     *
     * <p>A failure while applying the configured filters is logged and otherwise ignored, so the
     * data source is still returned without them.
     *
     * @return the configured {@link DruidDataSource}
     * @throws Exception declared for callers; the visible body handles {@link SQLException} itself
     */
    @Bean(name = "baseDatasource")
    @Primary
    public DataSource baseDatasource() throws Exception {
        DruidDataSource datasource = new DruidDataSource();
        datasource.setUrl(dbUrl);
        datasource.setUsername(username);
        datasource.setPassword(password);
        datasource.setDriverClassName(driverClassName);
        // configuration
        datasource.setInitialSize(initialSize);
        datasource.setMinIdle(minIdle);
        datasource.setMaxActive(maxActive);
        datasource.setMaxWait(maxWait);
        datasource.setTimeBetweenEvictionRunsMillis(timeBetweenEvictionRunsMillis);
        datasource.setMinEvictableIdleTimeMillis(minEvictableIdleTimeMillis);
        datasource.setValidationQuery(validationQuery);
        datasource.setTestWhileIdle(testWhileIdle);
        datasource.setTestOnBorrow(testOnBorrow);
        datasource.setTestOnReturn(testOnReturn);
        datasource.setPoolPreparedStatements(poolPreparedStatements);
        datasource.setMaxPoolPreparedStatementPerConnectionSize(maxPoolPreparedStatementPerConnectionSize);
        try {
            datasource.setFilters(filters);
        } catch (SQLException e) {
            // Best effort: the pool is still usable without the filters.
            logger.error("druid configuration initialization filter", e);
        }
        datasource.setConnectionProperties(connectionProperties);
        return datasource;
    }

    /**
     * Creates the primary {@link SqlSessionFactory} backed by {@code baseDatasource}.
     *
     * <p>Any failure while building the factory is propagated so that application startup fails
     * with the real cause, instead of publishing a {@code null} bean.
     *
     * @param dataSource the data source qualified as {@code baseDatasource}
     * @return the built session factory
     * @throws Exception if the factory cannot be built
     */
    @Bean("baseSqlSessionFactory")
    @Primary
    public SqlSessionFactory baseSqlSessionFactory(@Qualifier("baseDatasource") DataSource dataSource) throws Exception {
        SqlSessionFactoryBean sessionFactoryBean = new SqlSessionFactoryBean();
        sessionFactoryBean.setDataSource(dataSource);
        sessionFactoryBean.setTypeAliasesPackage(typeAliasesPackage);
        sessionFactoryBean.setConfigLocation(new DefaultResourceLoader().getResource(configLocation));
        sessionFactoryBean.setMapperLocations(resolveMapperLocations());
        return sessionFactoryBean.getObject();
    }

    /**
     * Creates the primary transaction manager for {@code baseDatasource}.
     *
     * @param dataSource the data source qualified as {@code baseDatasource}
     * @return a {@link DataSourceTransactionManager} bound to {@code dataSource}
     */
    @Bean(name = "baseTransactionManager")
    @Primary
    public DataSourceTransactionManager transactionManager(@Qualifier("baseDatasource") DataSource dataSource) {
        return new DataSourceTransactionManager(dataSource);
    }

    /**
     * Creates the primary {@link SqlSessionTemplate} on top of {@code baseSqlSessionFactory}.
     *
     * @param sqlSessionFactory the factory qualified as {@code baseSqlSessionFactory}
     * @return a template wrapping {@code sqlSessionFactory}
     */
    @Bean(name = "baseSqlSessionTemplate")
    @Primary
    public SqlSessionTemplate baseSqlSessionTemplate(@Qualifier("baseSqlSessionFactory") SqlSessionFactory sqlSessionFactory) {
        return new SqlSessionTemplate(sqlSessionFactory);
    }

    /**
     * Resolves every pattern in the comma-separated {@code mapperLocations} property.
     *
     * <p>Entries are trimmed and blank entries are skipped. A pattern that cannot be resolved is
     * logged (with its stack trace) and skipped, so the remaining patterns are still loaded.
     *
     * @return all resources matched by the configured patterns, possibly empty
     */
    private Resource[] resolveMapperLocations() {
        ResourcePatternResolver resourceResolver = new PathMatchingResourcePatternResolver();
        List<Resource> resources = new ArrayList<>();
        for (String rawLocation : mapperLocations.split(",")) {
            String mapperLocation = rawLocation.trim();
            if (mapperLocation.isEmpty()) {
                continue;
            }
            try {
                Resource[] mappers = resourceResolver.getResources(mapperLocation);
                resources.addAll(Arrays.asList(mappers));
            } catch (IOException e) {
                // Trailing throwable argument makes SLF4J print the stack trace as well.
                logger.error("Fail to get mybatis resources, due to {}", e.getMessage(), e);
            }
        }
        return resources.toArray(new Resource[0]);
    }
}
1、自定义拦截器
背景介绍:在做防SQL注入时,比如对于传入的sql:select * from xx_table where user_id = ?,此处使用?作为占位符,通过MyBatis校验引擎进行参数赋值(类似JDBC的setParam)
通用接口:
/**
 * Generic query that takes a SQL string plus a list of values.
 *
 * The first argument is exposed to MyBatis under the name "sqlParams" and the second under the
 * name "sql". QueryInterceptor in this document looks up "sqlParams" and binds its elements to
 * the statement by position (1, 2, 3, ...).
 * NOTE(review): presumably "sql" holds the query text with one '?' per list element - confirm
 * against the mapper definition, which is not shown here.
 */
List<T> queryBySqlAndParams(@Param("sqlParams") List<Object> var1, @Param("sql") String var2);
查询拦截器:
package com.xxx.interceptor;
import java.sql.PreparedStatement;
import java.sql.Statement;
import java.util.List;
import java.util.Properties;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
/**
 * Generic interceptor for "queryBySql + parameter list" mapper methods.
 *
 * <p>Intercepts {@code StatementHandler.query(Statement, ResultHandler)}. When the bound
 * parameter object is a {@link MapperMethod.ParamMap} that holds a {@link List} under the key
 * {@code sqlParams}, each list element is bound to the statement by position before the query
 * proceeds. In every other case the invocation proceeds untouched.
 **/
@Intercepts(value = {@Signature(type = StatementHandler.class, method = "query", args = {Statement.class, ResultHandler.class})})
public class QueryInterceptor implements Interceptor {

    /** Name of the mapper parameter that carries the positional values. */
    private static final String SQL_PARAMS_KEY = "sqlParams";

    /**
     * Binds the {@code sqlParams} list (if present) to the statement, then runs the query.
     *
     * <p>The mapper method is expected to look like
     * {@code List<T> queryBySqlAndParams(@Param("sqlParams") List<Object> params, @Param("sql") String sql)}.
     * Calls whose parameter object is not a {@code ParamMap} are left alone.
     *
     * @param invocation the intercepted call; its first argument is the JDBC statement
     * @return whatever the intercepted {@code query} call returns
     * @throws Throwable any error raised while binding or by the intercepted call; a
     *         {@link ClassCastException} is raised if {@code sqlParams} is supplied but the
     *         statement is not a {@link PreparedStatement}
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        Object parameterObject = statementHandler.getBoundSql().getParameterObject();
        if (parameterObject instanceof MapperMethod.ParamMap) {
            MapperMethod.ParamMap<?> paramMap = (MapperMethod.ParamMap<?>) parameterObject;
            // containsKey is checked before get, as in the original lookup order.
            if (paramMap.containsKey(SQL_PARAMS_KEY)) {
                Object sqlParams = paramMap.get(SQL_PARAMS_KEY);
                // instanceof is false for null, so no separate null check is needed.
                if (sqlParams instanceof List) {
                    PreparedStatement statement = (PreparedStatement) invocation.getArgs()[0];
                    int index = 1; // JDBC parameter indexes start at 1
                    for (Object value : (List<?>) sqlParams) {
                        statement.setObject(index++, value);
                    }
                }
            }
        }
        return invocation.proceed();
    }

    /**
     * Wraps the target with this interceptor.
     *
     * @param o the object MyBatis offers for interception
     * @return the result of {@link Plugin#wrap(Object, Interceptor)}
     */
    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    /**
     * No configuration properties are used by this interceptor.
     *
     * @param properties ignored
     */
    @Override
    public void setProperties(Properties properties) {
    }
}
拦截器配置:此处使用MyBatis-Plus中的MybatisSqlSessionFactoryBean创建 SqlSessionFactory
// Build the SqlSessionFactory through the MyBatis-Plus factory bean.
MybatisSqlSessionFactoryBean fb = new MybatisSqlSessionFactoryBean();
// dataSource is defined outside this snippet.
fb.setDataSource(dataSource);
// Register the custom QueryInterceptor as a plugin on the factory bean.
fb.setPlugins(new QueryInterceptor());
2、PaginationInterceptor 拦截器@Deprecated问题
背景介绍:新版本MyBatis-Plus中PaginationInterceptor已被标记为@Deprecated,继续使用会出现过期警告。
如何解决呢?还是结合MyBatis-Plus,改用MybatisPlusInterceptor + PaginationInnerInterceptor:
/**
 * Creates the MyBatis-Plus interceptor with a pagination inner interceptor registered on it.
 *
 * @return a {@link MybatisPlusInterceptor} holding one {@link PaginationInnerInterceptor}
 */
@Bean
public MybatisPlusInterceptor mybatisPlusInterceptor() {
    PaginationInnerInterceptor pagination = new PaginationInnerInterceptor();
    MybatisPlusInterceptor chain = new MybatisPlusInterceptor();
    chain.addInnerInterceptor(pagination);
    return chain;
}
最后拦截器配置:
/**
 * Builds the primary {@link SqlSessionFactory} with the MyBatis-Plus interceptor installed.
 *
 * <p>{@code xxx} is a placeholder for the real config location / type-aliases package.
 *
 * @param dataSource  the data source qualified as {@code dataSource}
 * @param interceptor the MyBatis-Plus interceptor to register as a plugin
 * @return the session factory produced by the factory bean
 * @throws Exception if the factory bean fails to build the session factory
 */
@Bean(name = "sqlSessionFactory")
@Primary
public SqlSessionFactory sqlSessionFactory(@Qualifier("dataSource") DataSource dataSource, MybatisPlusInterceptor interceptor) throws Exception {
    MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
    // Plugin registration is independent of the other setters.
    factoryBean.setPlugins(interceptor);
    factoryBean.setDataSource(dataSource);
    factoryBean.setTypeAliasesPackage(xxx);
    factoryBean.setConfigLocation(xxx);
    return factoryBean.getObject();
}
3、如何缓存所有Mapper Class Name和Mapper
背景介绍:项目中使用了统一的dbwriter写入服务,dbwriter服务会加载所有业务的EntityMapper,
所以在服务启动的时候需要知道所有EntityMapper和Mapper Proxy的对应关系,方便通过Entity获取对应的BaseMapper,从而操作数据库
如何缓存entity和mapper的对应关系呢?
前提是使用到了 SqlSessionTemplate进行Mapper Package扫描
最后如何获取?
@Autowired
private List<SqlSessionTemplate> sqlSessionTemplateList;
sqlSessionTemplateList.forEach(item -> {
try {
MapperRegistry mapperRegistry = item.getConfiguration().getMapperRegistry();
Collection<Class<?>> mappers = mapperRegistry.getMappers();
mappers.forEach(eachMapper -> {
BaseMapper mapper = (BaseMapper) mapperRegistry.getMapper(eachMapper, item);
String className = eachMapper.getName();
entityMapperMap.put(className, mapper);
});
} catch (Exception e) {
logger.error("resolve exception!", e);
}
});
4、自定义SQL注入器方法
背景介绍:项目中使用MyBatis-Plus开发时,BaseMapper提供的方法往往无法完全满足需求,需要自定义一些扩展方法,比如自定义动态sql查询、动态批量插入sql等等。
如何自定义?
首先还是继承BaseMapper设计自己的MyMapper接口。
/**
 * Project base mapper: extends the MyBatis-Plus BaseMapper and adds queryBySql, which takes a
 * single String and returns a list of T.
 * NOTE(review): presumably the String argument is the SQL text to run - confirm against the
 * injected statement definition.
 */
public interface MyMapper<T> extends BaseMapper<T> { List<T> queryBySql(String var1); }
其次配置具体queryBySql执行的方法
/**
 * Injectable method definition that registers a select mapped statement named
 * {@code queryBySql} for a mapper.
 */
@Slf4j
public class QueryBySql extends AbstractMethod {

    /** Id of the mapped statement; also the value returned by {@link #getMethod(SqlMethod)}. */
    private static final String METHOD_NAME = "queryBySql";

    /**
     * Creates the SQL source and registers it as a select statement for the given table.
     *
     * @param mapperClass the mapper interface the statement is added to
     * @param modelClass  the entity class, passed to the language driver as parameter type
     * @param tableInfo   table metadata used when registering the statement
     * @return the mapped statement that was added
     */
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        // NOTE(review): `sql` is not declared anywhere in this class as shown; it has to come
        // from a field/constant omitted from the snippet - confirm before using this code.
        return this.addSelectMappedStatementForTable(
                mapperClass,
                METHOD_NAME,
                languageDriver.createSqlSource(configuration, sql, modelClass),
                tableInfo);
    }

    /**
     * Returns the fixed statement name, ignoring the supplied {@code sqlMethod}.
     *
     * @param sqlMethod unused
     * @return {@code "queryBySql"}
     */
    @Override
    public String getMethod(SqlMethod sqlMethod) {
        return METHOD_NAME;
    }
}
最后注入器配置
/**
 * Builds the primary {@link SqlSessionFactory} with the MyBatis-Plus interceptor and a custom
 * SQL injector installed through {@link GlobalConfig}.
 *
 * <p>{@code xxx} is a placeholder for the real config location / type-aliases package.
 *
 * @param dataSource  the data source qualified as {@code dataSource}
 * @param interceptor the MyBatis-Plus interceptor to register as a plugin
 * @return the session factory produced by the factory bean
 * @throws Exception if the factory bean fails to build the session factory
 */
@Bean(name = "sqlSessionFactory")
@Primary
public SqlSessionFactory sqlSessionFactory(@Qualifier("dataSource") DataSource dataSource, MybatisPlusInterceptor interceptor) throws Exception {
    // Database-level defaults: primary keys use IdType.AUTO.
    GlobalConfig.DbConfig databaseDefaults = new GlobalConfig.DbConfig();
    databaseDefaults.setIdType(IdType.AUTO);

    // Global MyBatis-Plus settings, including the custom SQL injector.
    GlobalConfig globalSettings = new GlobalConfig();
    globalSettings.setSqlInjector(new MySqlInjector());
    globalSettings.setDbConfig(databaseDefaults);

    MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
    factoryBean.setDataSource(dataSource);
    factoryBean.setConfigLocation(xxx);
    factoryBean.setTypeAliasesPackage(xxx);
    factoryBean.setPlugins(interceptor);
    factoryBean.setGlobalConfig(globalSettings);
    return factoryBean.getObject();
}