MyBatis-Plus 是在 MyBatis 持久层框架之上封装的一层非常好用的工具。最近因为想在 Mapper 里加入自定义的通用方法，所以用到了 MyBatis-Plus 的 SQL 注入器。SQL 注入器可以把自定义的 SQL 脚本注入到 MappedStatement 中，从而动态拼装 SQL 并生成 Mapper 接口方法。它与自己手写一个通用 Mapper 的不同之处在于：MyBatis-Plus 提供的 AbstractMethod 抽象类，在其 injectMappedStatement 方法中可以拿到表信息（TableInfo，包括表名、主键、字段列表），我们可以利用这些信息拼装批量插入和批量更新的 SQL。此外，MyBatis-Plus 也自带了一些 AbstractMethod 的实现类。下面以批量更新和批量插入两个示例贴出代码，供大家参考。
自定义批量更新方法 UpdateBatchMethod.java。重写（override）injectMappedStatement 方法，在该方法中生成拼接批量更新 SQL 的脚本。
import cn.hutool.db.Entity;
import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
/**
 * MyBatis-Plus injected method that generates the {@code updateBatch} mapped statement.
 *
 * <p>The statement updates many rows with a single UPDATE. Every non-key column is assigned a
 * {@code CASE <key> WHEN ... THEN ... ELSE <column> END} expression, and the affected rows are
 * restricted with {@code WHERE <key> IN (...)}. Example for table {@code user(id, name, age)} and
 * two items:
 *
 * <pre>
 * update user set
 *   name = (case when id=1 then ifnull('Zhang San', name) when id=2 then ifnull('Li Si', name) else name end),
 *   age  = (case when id=1 then ifnull(2, age) when id=2 then ifnull(2, age) else age end)
 * where id in (1, 2)
 * </pre>
 *
 * <p>Each value is wrapped in {@code ifnull(value, column)}, so a {@code null} property keeps the
 * current column value. {@code ifnull} is a MySQL function, so this SQL targets MySQL.
 * The mapper method must receive its list under the parameter name {@code "list"}, because the
 * generated script iterates {@code collection="list"}.
 */
public class UpdateBatchMethod extends AbstractMethod {

    /** Key column/property used when the entity declares no primary key (previous hard-coded value). */
    private static final String DEFAULT_KEY = "id";

    /**
     * Builds the {@code updateBatch} statement for one mapper/entity pair.
     *
     * @param mapperClass the mapper interface the statement is registered on
     * @param modelClass  the entity class
     * @param tableInfo   table metadata resolved by MyBatis-Plus for {@code modelClass}
     * @return the registered update {@link MappedStatement}
     */
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        // Use the entity's real primary key instead of assuming it is called "id".
        final String keyColumn = keyColumn(tableInfo);
        final String keyProperty = keyProperty(tableInfo);
        final String sql = "<script>\n update %s %s \n where %s in \n <foreach collection=\"list\" item=\"item\" separator=\",\" open=\"(\" close=\")\">\n #{item.%s} </foreach> \n </script>";
        final String valueSql = prepareValuesSql(tableInfo, keyColumn, keyProperty);
        final String sqlResult = String.format(sql, tableInfo.getTableName(), valueSql, keyColumn, keyProperty);
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, sqlResult, modelClass);
        return this.addUpdateMappedStatement(mapperClass, modelClass, "updateBatch", sqlSource);
    }

    /**
     * Builds the {@code SET} clause: one {@code column = (case ... end)} assignment per non-key field.
     * The outer trim removes the trailing comma of the last assignment.
     * NOTE(review): if the entity has no non-key fields the SET clause is empty and the SQL is invalid.
     */
    private String prepareValuesSql(TableInfo tableInfo, String keyColumn, String keyProperty) {
        final StringBuilder valueSql = new StringBuilder();
        valueSql.append("<trim prefix=\"set\" suffixOverrides=\",\">\n");
        tableInfo.getFieldList().forEach(x -> {
            valueSql.append("<trim prefix=\"").append(x.getColumn()).append(" =(case \" suffix=\"end),\">\n");
            valueSql.append("<foreach collection=\"list\" item=\"item\" >\n");
            // ifnull(...) keeps the stored value when the item's property is null.
            valueSql.append("when ").append(keyColumn).append("=#{item.").append(keyProperty)
                    .append("} then ifnull(#{item.").append(x.getProperty()).append("},")
                    .append(x.getColumn()).append(")\n");
            valueSql.append("</foreach>\n");
            valueSql.append("else ").append(x.getColumn());
            valueSql.append("</trim>\n");
        });
        valueSql.append("</trim>\n");
        return valueSql.toString();
    }

    /** Primary-key column name, or {@code "id"} when the entity declares none. */
    private static String keyColumn(TableInfo tableInfo) {
        final String column = tableInfo.getKeyColumn();
        return column == null || column.isEmpty() ? DEFAULT_KEY : column;
    }

    /** Primary-key property name, or {@code "id"} when the entity declares none. */
    private static String keyProperty(TableInfo tableInfo) {
        final String property = tableInfo.getKeyProperty();
        return property == null || property.isEmpty() ? DEFAULT_KEY : property;
    }
}
自定义 SQL 注入器 InsertBatchSqlInjector.java。将上面的批量更新方法对象添加到默认 SQL 注入器的方法列表中，同时添加的还有 MyBatis-Plus 自带的批量新增方法 InsertBatchSomeColumn。
import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.injector.DefaultSqlInjector;
import com.baomidou.mybatisplus.extension.injector.methods.additional.InsertBatchSomeColumn;
import java.util.List;
/**
 * SQL injector that extends MyBatis-Plus' default method set with the two batch methods
 * declared on {@code MyBaseMapper}: {@code insertBatchSomeColumn} and {@code updateBatch}.
 */
public class InsertBatchSqlInjector extends DefaultSqlInjector {

    /**
     * Returns the default injected methods followed by the batch insert and batch update methods.
     *
     * @param mapperClass the mapper interface being processed
     * @return the list obtained from the default injector, with the two batch methods appended
     */
    @Override
    public List<AbstractMethod> getMethodList(Class<?> mapperClass) {
        final List<AbstractMethod> injected = super.getMethodList(mapperClass);
        // Built-in MyBatis-Plus multi-row insert, exposed as insertBatchSomeColumn.
        injected.add(new InsertBatchSomeColumn());
        // Custom CASE-WHEN batch update, exposed as updateBatch.
        injected.add(new UpdateBatchMethod());
        return injected;
    }
}
注册 SQL 注入器 MybatisPlusConfig.java。将上面自定义的 SQL 注入器注册为 Spring 容器中的 Bean。
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * MyBatis-Plus configuration that exposes the custom SQL injector as a Spring bean.
 */
@Configuration
public class MybatisPlusConfig {

    /**
     * Declares the injector that adds {@code insertBatchSomeColumn} and {@code updateBatch}
     * to every mapper.
     *
     * @return a new {@link InsertBatchSqlInjector}
     */
    @Bean
    public InsertBatchSqlInjector insertBatchSqlInjector() {
        final InsertBatchSqlInjector injector = new InsertBatchSqlInjector();
        return injector;
    }
}
自定义 BaseMapper：MyBaseMapper.java。
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
/**
 * Base mapper that adds batch write methods on top of MyBatis-Plus {@link BaseMapper}.
 * The SQL for these methods is not written here; it is supplied by the injected methods that
 * InsertBatchSqlInjector registers (InsertBatchSomeColumn and UpdateBatchMethod).
 *
 * @param <T> entity type
 */
public interface MyBaseMapper<T> extends BaseMapper<T> {
/**
 * Batch insert. The statement is provided by MyBatis-Plus' InsertBatchSomeColumn
 * (presumably a single multi-row INSERT — confirm against the MyBatis-Plus version in use).
 *
 * @param batchList entities to insert; bound under the name "list"
 * @return presumably the affected-row count reported by the driver
 */
int insertBatchSomeColumn(@Param("list") List<T> batchList);
/**
 * Batch update. The statement is generated by UpdateBatchMethod, whose script iterates
 * collection="list", so the @Param name must stay "list".
 *
 * @param list entities to update, matched by primary key
 * @return presumably the affected-row count reported by the driver
 */
int updateBatch(@Param("list") List<T> list);
}
然后，将业务 Mapper 接口继承的父接口从 BaseMapper 改为上面创建好的 MyBaseMapper。
自定义 ServiceImpl：MyServiceImpl.java。
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.smcaiot.wcpa.safedata.mapper.MyBaseMapper;
import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
public class MyServiceImpl<M extends MyBaseMapper<T>, T> extends ServiceImpl<M, T> {
@Override
public boolean saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
if (CollUtil.isEmpty(entityList)) {
return false;
}
List<T> updates = entityList.stream().filter(x -> !isIdNull(x)).collect(Collectors.toList());
List<T> inserts = entityList.stream().filter(x -> isIdNull(x)).collect(Collectors.toList());
int count = 0;
List<T> tmpList = new ArrayList<>();
if (CollUtil.isNotEmpty(inserts)) {
for (T insert : inserts) {
int i = tmpList.size();
if (i >= 1 && i % batchSize == 0) {
count += getBaseMapper().insertBatchSomeColumn(tmpList);
tmpList.clear();
}
tmpList.add(insert);
}
count += getBaseMapper().insertBatchSomeColumn(tmpList);
tmpList.clear();
}
if (CollUtil.isNotEmpty(updates)) {
for (T update : updates) {
int i = tmpList.size();
if (i >= 1 && i % batchSize == 0) {
count += getBaseMapper().updateBatch(tmpList);
tmpList.clear();
}
tmpList.add(update);
}
count += getBaseMapper().updateBatch(tmpList);
tmpList.clear();
}
return count > 0;
}
private boolean isIdNull(Object obj) {
return Objects.isNull(BeanUtil.getProperty(obj, "id"));
}
}
然后，将业务 Service 实现类的父类改为上面创建的 MyServiceImpl。
单元测试配置 TestMybatisPlusConfig.java：
import com.baomidou.mybatisplus.core.config.GlobalConfig;
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
import com.smcaiot.wcpa.safedata.app.InsertBatchSqlInjector;
import com.smcaiot.wcpa.safedata.app.MybatisPlusConfig;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.FilterType;
import javax.sql.DataSource;
/**
 * Stand-alone Spring configuration for MyBatis-Plus unit tests. It builds a Hikari data source,
 * a MybatisSqlSessionFactoryBean and an auto-commit SqlSession, and installs the custom
 * {@link InsertBatchSqlInjector} so mappers get {@code insertBatchSomeColumn} / {@code updateBatch}.
 */
@Configuration
@MapperScan({"com.xxx.mapper"})
// Scan the package that actually contains MybatisPlusConfig. The previous hard-coded package
// ("com.xxx.config") did not contain it, so the injector bean was never registered and the
// @Autowired field below failed.
@ComponentScan(basePackageClasses = MybatisPlusConfig.class, includeFilters = {
        @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = {MybatisPlusConfig.class})
}, useDefaultFilters = false)
public class TestMybatisPlusConfig {

    /** Injector bean declared by MybatisPlusConfig. */
    @Autowired
    private InsertBatchSqlInjector insertBatchSqlInjector;

    /** Session factory wired to the test data source and the global config carrying the injector. */
    @Bean
    public MybatisSqlSessionFactoryBean mybatisSqlSessionFactoryBean() throws Exception {
        MybatisSqlSessionFactoryBean sqlSessionFactoryBean = new MybatisSqlSessionFactoryBean();
        sqlSessionFactoryBean.setDataSource(dataSource());
        sqlSessionFactoryBean.setGlobalConfig(globalConfig());
        return sqlSessionFactoryBean;
    }

    /** SqlSession opened with auto-commit enabled. */
    @Bean
    public SqlSession sqlSession() throws Exception {
        return mybatisSqlSessionFactoryBean().getObject().openSession(true);
    }

    /** Global config whose SQL injector is the custom batch injector. */
    @Bean
    public GlobalConfig globalConfig() {
        GlobalConfig globalConfig = new GlobalConfig();
        globalConfig.setSqlInjector(insertBatchSqlInjector);
        return globalConfig;
    }

    /**
     * Hikari data source for a MySQL test database. URL and credentials are placeholders;
     * supply real values from the environment rather than committing them.
     */
    @Bean
    public DataSource dataSource() {
        HikariConfig config = new HikariConfig();
        config.setJdbcUrl("jdbc:mysql://dburl:3306/dbname?useSSL=false&characterEncoding=utf-8&serverTimezone=GMT%2B8&allowMultiQueries=true&rewriteBatchedStatements=true");
        config.setUsername("username");
        config.setPassword("password");
        config.setDriverClassName("com.mysql.cj.jdbc.Driver");
        return new HikariDataSource(config);
    }
}