MyBatis增强-实现通用的基础查询

1. 设想

设计一个通用的基础Mapper,仅通过接口的继承实现每个表的基础查询(增删改查)。

2.引入MyBatis并配置

2.1 引入相关Maven依赖
  • MyBatis
<dependency>
    <groupId>org.mybatis.spring.boot</groupId>
    <artifactId>mybatis-spring-boot-starter</artifactId>
    <version>2.1.1</version>
</dependency>
  • MySQL JDBC驱动
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>5.1.44</version>
</dependency>
  • 数据库连接池
<dependency>
	<groupId>com.alibaba</groupId>
    <artifactId>druid</artifactId>
    <version>1.0.28</version>
</dependency>
2.2 配置数据源
spring:
  datasource:
    type: com.alibaba.druid.pool.DruidDataSource
    driver-class-name: com.mysql.jdbc.Driver
    url: jdbc:mysql://www.xxx.com:3306/db?useUnicode=yes&characterEncoding=utf-8&allowMultiQueries=true&useSSL=false
    username: abc
    password: 123456
2.3 MyBatis配置
mybatis:
  config-location: classpath:mybatis/mybatis-config.xml
  mapper-locations: classpath:mybatis/mapper/*.xml
  type-aliases-package: com.vz.mybatis.enhance.entity
2.4 Mybatis配置文件mybatis-config.xml
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE configuration PUBLIC "-//mybatis.org//DTD Config 3.0//EN" "http://mybatis.org/dtd/mybatis-3-config.dtd">
<!-- Global MyBatis settings and type aliases for this project -->
<configuration>
    <settings>
        <setting name="callSettersOnNulls" value="true"/>

        <setting name="cacheEnabled" value="true"/>

        <setting name="lazyLoadingEnabled" value="true"/>

        <setting name="aggressiveLazyLoading" value="true"/>

        <setting name="multipleResultSetsEnabled" value="true"/>

        <setting name="useColumnLabel" value="true"/>

        <setting name="useGeneratedKeys" value="false"/>

        <setting name="autoMappingBehavior" value="PARTIAL"/>

        <setting name="defaultExecutorType" value="SIMPLE"/>

        <!-- Map underscore column names to camelCase properties; this must be enabled -->
        <setting name="mapUnderscoreToCamelCase" value="true"/>

        <setting name="localCacheScope" value="SESSION"/>

        <setting name="jdbcTypeForNull" value="NULL"/>

        <setting name="logImpl" value="NO_LOGGING" />
		<!-- <setting name="logImpl" value="STDOUT_LOGGING" />-->

    </settings>

    <typeAliases>
        <typeAlias alias="Integer" type="java.lang.Integer" />
        <typeAlias alias="Long" type="java.lang.Long" />
        <typeAlias alias="HashMap" type="java.util.HashMap" />
        <typeAlias alias="LinkedHashMap" type="java.util.LinkedHashMap" />
        <typeAlias alias="ArrayList" type="java.util.ArrayList" />
        <typeAlias alias="LinkedList" type="java.util.LinkedList" />
        <package name="com.vz.mybatis.enhance.entity"/>
    </typeAliases>
    
</configuration>

3.通用Mapper定义:BaseMapper

定义通用的基础SQL操作方法

package com.vz.mybatis.enhance.common.mapper.core;

import com.vz.mybatis.enhance.common.mapper.qr.Querier;
import org.apache.ibatis.annotations.*;

import java.util.Collection;
import java.util.List;

/**
 * Generic base mapper declaring the common CRUD operations for a single table.
 * The SQL of every method is supplied by {@link BaseSqlProvider}.
 *
 * @author visy.wang
 * @date 2023/4/24 12:59
 * @param <T> entity type mapped to the database table
 * @param <K> primary key type
 */
public interface BaseMapper<T,K>{
    /**
     * Selects a record by primary key.
     * @param id primary key
     * @return the record
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "selectById")
    T selectById(@Param(BaseSqlProvider.ID_NAME) K id);

    /**
     * Selects records in batch by a list of primary keys.
     * @param idList primary key list
     * @return list of records
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "selectByIds")
    List<T> selectByIds(@Param(BaseSqlProvider.IDS_NAME) Collection<K> idList);

    /**
     * Selects one record by condition (when several records match, the first one is taken).
     * @param querier query conditions
     * @return the record
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "selectOne")
    T selectOne(@Param(BaseSqlProvider.QUERIER_NAME) Querier<T> querier);

    /**
     * Selects a list of records by condition.
     * @param querier query conditions
     * @return list of records
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "selectList")
    List<T> selectList(@Param(BaseSqlProvider.QUERIER_NAME) Querier<T> querier);

    /**
     * Selects all records.
     * @return list of records
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "selectAll")
    List<T> selectAll();

    /**
     * Counts the records matching the condition.
     * @param querier query conditions
     * @return number of records
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "count")
    long count(@Param(BaseSqlProvider.QUERIER_NAME) Querier<T> querier);

    /**
     * Counts all records.
     * @return number of records
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "countAll")
    long countAll();

    /**
     * Deletes a record by primary key.
     * @param id primary key
     * @return number of deleted rows
     */
    @DeleteProvider(type = BaseSqlProvider.class, method = "deleteById")
    int deleteById(@Param(BaseSqlProvider.ID_NAME) K id);

    /**
     * Deletes records in batch by a list of primary keys.
     * @param idList primary key list
     * @return number of deleted rows
     */
    @DeleteProvider(type = BaseSqlProvider.class, method = "deleteByIds")
    int deleteByIds(@Param(BaseSqlProvider.IDS_NAME) Collection<K> idList);

    /**
     * Deletes records by condition.
     * @param querier filter conditions
     * @return number of deleted rows
     */
    @DeleteProvider(type = BaseSqlProvider.class, method = "delete")
    int delete(@Param(BaseSqlProvider.QUERIER_NAME) Querier<T> querier);

    /**
     * Inserts one record (fields that are null are included).
     * @param record record data
     * @return number of inserted rows
     */
    @InsertProvider(type = BaseSqlProvider.class, method = "insert")
    int insert(@Param(BaseSqlProvider.ENTITY_NAME) T record);

    /**
     * Inserts one record (fields that are null are omitted).
     * @param record record data
     * @return number of inserted rows
     */
    @InsertProvider(type = BaseSqlProvider.class, method = "insertSelective")
    int insertSelective(@Param(BaseSqlProvider.ENTITY_NAME) T record);

    /**
     * Updates a record by primary key (fields that are null are included).
     * @param record record data to write (the primary key value must not be null)
     * @return number of updated rows
     */
    @UpdateProvider(type = BaseSqlProvider.class, method = "updateById")
    int updateById(@Param(BaseSqlProvider.ENTITY_NAME) T record);

    /**
     * Updates a record by primary key (fields that are null are omitted).
     * @param record record data to write (the primary key value must not be null)
     * @return number of updated rows
     */
    @UpdateProvider(type = BaseSqlProvider.class, method = "updateByIdSelective")
    int updateByIdSelective(@Param(BaseSqlProvider.ENTITY_NAME) T record);

    /**
     * Updates records by condition (fields of record that are null are included).
     * @param record record data to write
     * @param querier update conditions
     * @return number of updated rows
     */
    @UpdateProvider(type = BaseSqlProvider.class, method = "update")
    int update(@Param(BaseSqlProvider.ENTITY_NAME) T record, @Param(BaseSqlProvider.QUERIER_NAME) Querier<T> querier);

    /**
     * Updates records by condition (fields of record that are null are omitted).
     * @param record record data to write
     * @param querier update conditions
     * @return number of updated rows
     */
    @UpdateProvider(type = BaseSqlProvider.class, method = "updateSelective")
    int updateSelective(@Param(BaseSqlProvider.ENTITY_NAME) T record, @Param(BaseSqlProvider.QUERIER_NAME) Querier<T> querier);
}

4.定义SQL生成器:BaseSqlProvider

为BaseMapper中的方法提供SQL,定义了每个SQL语句的生成规则,
且采用预编译(#{}占位符)的方式传参(防止SQL注入)

package com.vz.mybatis.enhance.common.mapper.core;

import com.vz.mybatis.enhance.common.mapper.hp.MapperHelper;
import com.vz.mybatis.enhance.common.mapper.hp.SqlHelper;
import com.vz.mybatis.enhance.common.mapper.inf.COLUMN_INF;
import com.vz.mybatis.enhance.common.mapper.inf.TABLE_INF;
import com.vz.mybatis.enhance.common.mapper.qr.BaseExample;
import com.vz.mybatis.enhance.common.mapper.qr.Criterion;
import com.vz.mybatis.enhance.common.mapper.qr.Querier;
import org.apache.ibatis.builder.annotation.ProviderContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.util.*;

/**
 * Builds the SQL text for the methods declared on {@code BaseMapper}.
 * <p>
 * Column values are never concatenated into the SQL: every provider method emits
 * {@code #{...}} placeholders and registers the matching values in the MyBatis
 * parameter map it receives.
 *
 * @author visy.wang
 * @date 2023/4/24 13:18
 */
public class BaseSqlProvider {
    private static final Logger logger = LoggerFactory.getLogger(BaseSqlProvider.class);
    public static final String ID_NAME = "id";
    public static final String IDS_NAME = "ids";
    public static final String ENTITY_NAME = "record";
    public static final String QUERIER_NAME = "querier";
    /** Marker returned by {@link #readValue} when reflection fails, so a failed read differs from a null value. */
    private static final Object UNREADABLE = new Object();

    /** SELECT of all columns filtered by the primary key bound as {@code #{id}}. */
    public String selectById(Map<String,Object> params, ProviderContext context){
        TABLE_INF table = MapperHelper.getTable(context);
        COLUMN_INF pkColumn = table.getPkColumn();
        return SqlHelper.sql()
                .select(table.allColumns())
                .from(table.getTableName())
                .where(pkColumn.getColumn() + " = #{"+ID_NAME+"}")
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** SELECT of all columns whose primary key is in the collection bound as {@code ids}. */
    public String selectByIds(Map<String,Object> params, ProviderContext context){
        Object idList = params.get(IDS_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        COLUMN_INF pkColumn = table.getPkColumn();
        return SqlHelper.sql()
                .select(table.allColumns())
                .from(table.getTableName())
                .where(getInCondition(pkColumn.getColumn(), IDS_NAME, idList))
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** SELECT by querier conditions, limited to one row. */
    public String selectOne(Map<String,Object> params, ProviderContext context){
        Querier<?> querier = (Querier<?>)params.get(QUERIER_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        BaseExample example = querier.getExample();
        return SqlHelper.sql()
                .select(table.allColumns(), example.getDistinct())
                .from(table.getTableName())
                .where(getConditions(example, params))
                .orderBy(example.getOrderByClause())
                .limit("1") // keep only the first row of the result
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** SELECT by querier conditions, with the querier's ORDER BY and LIMIT clauses. */
    public String selectList(Map<String,Object> params, ProviderContext context){
        Querier<?> querier = (Querier<?>)params.get(QUERIER_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        BaseExample example = querier.getExample();
        return SqlHelper.sql()
                .select(table.allColumns(), example.getDistinct())
                .from(table.getTableName())
                .where(getConditions(example, params))
                .orderBy(example.getOrderByClause())
                .limit(example.getLimitClause())
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** SELECT of all columns without any condition. */
    public String selectAll(ProviderContext context){
        TABLE_INF table = MapperHelper.getTable(context);
        return SqlHelper.sql()
                .select(table.allColumns())
                .from(table.getTableName())
                .toStr(sql -> log(context, sql, null));
    }

    /** COUNT over the primary key column, filtered by the querier conditions. */
    public String count(Map<String,Object> params, ProviderContext context){
        Querier<?> querier = (Querier<?>)params.get(QUERIER_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        BaseExample example = querier.getExample();
        return SqlHelper.sql()
                .count(table.getPkColumn().getColumn(), example.getDistinct())
                .from(table.getTableName())
                .where(getConditions(example, params))
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** COUNT over the primary key column without any condition. */
    public String countAll(ProviderContext context){
        TABLE_INF table = MapperHelper.getTable(context);
        return SqlHelper.sql()
                .count(table.getPkColumn().getColumn())
                .from(table.getTableName())
                .toStr(sql -> log(context, sql, null));
    }

    /** DELETE filtered by the primary key bound as {@code #{id}}. */
    public String deleteById(Map<String,Object> params, ProviderContext context){
        TABLE_INF table = MapperHelper.getTable(context);
        COLUMN_INF pkColumn = table.getPkColumn();
        return SqlHelper.sql()
                .delete()
                .from(table.getTableName())
                .where(pkColumn.getColumn() + " = #{"+ID_NAME+"}")
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** DELETE of the rows whose primary key is in the collection bound as {@code ids}. */
    public String deleteByIds(Map<String,Object> params, ProviderContext context){
        Object idList = params.get(IDS_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        COLUMN_INF pkColumn = table.getPkColumn();
        return SqlHelper.sql()
                .delete()
                .from(table.getTableName())
                .where(getInCondition(pkColumn.getColumn(), IDS_NAME, idList))
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** DELETE filtered by the querier conditions. */
    public String delete(Map<String,Object> params, ProviderContext context){
        Querier<?> querier = (Querier<?>)params.get(QUERIER_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        BaseExample example = querier.getExample();
        return SqlHelper.sql()
                .delete()
                .from(table.getTableName())
                .where(getConditions(example, params))
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /**
     * INSERT listing every non-primary-key column, null values included.
     * A property that cannot be read is bound as null (logged as an error).
     */
    public String insert(Map<String,Object> params, ProviderContext context){
        Object entity = params.get(ENTITY_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        List<String> columns = new ArrayList<>(), values = new ArrayList<>();
        table.getColumns().forEach(item -> {
            if(item.getIsPK()){
                // skip the primary key: it is produced by the database auto-increment
                return;
            }
            columns.add(item.getColumn());
            String property = item.getProperty();
            values.add("#{"+property+"}");
            Object value = readValue(item.getField(), property, entity);
            params.put(property, value == UNREADABLE ? null : value);
        });

        return SqlHelper.sql()
                .insert(table.getTableName())
                .values(columns, values)
                .toStr(sql -> log(context, sql, entity));
    }

    /** INSERT listing only the non-primary-key columns whose property value is non-null. */
    public String insertSelective(Map<String,Object> params, ProviderContext context){
        Object entity = params.get(ENTITY_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        List<String> columns = new ArrayList<>(), values = new ArrayList<>();
        table.getColumns().forEach(item -> {
            if(item.getIsPK()){
                // skip the primary key: it is produced by the database auto-increment
                return;
            }
            String property = item.getProperty();
            Object value = readValue(item.getField(), property, entity);
            if(value == UNREADABLE || Objects.isNull(value)){
                return;
            }
            columns.add(item.getColumn());
            values.add("#{"+property+"}");
            params.put(property, value);
        });

        return SqlHelper.sql()
                .insert(table.getTableName())
                .values(columns, values)
                .toStr(sql -> log(context, sql, entity));
    }

    /**
     * UPDATE of every non-primary-key column (nulls included), filtered by the primary key.
     * @throws IllegalArgumentException if the primary key value of the record is null
     * @throws IllegalStateException if the primary key cannot be read or no primary key column exists
     */
    public String updateById(Map<String,Object> params, ProviderContext context){
        return updateByPrimaryKey(params, context, false);
    }

    /**
     * UPDATE of the non-null non-primary-key columns, filtered by the primary key.
     * @throws IllegalArgumentException if the primary key value of the record is null
     * @throws IllegalStateException if the primary key cannot be read or no primary key column exists
     */
    public String updateByIdSelective(Map<String,Object> params, ProviderContext context){
        return updateByPrimaryKey(params, context, true);
    }

    /** UPDATE of every non-primary-key column (nulls included), filtered by the querier conditions. */
    public String update(Map<String,Object> params, ProviderContext context){
        return updateByQuerier(params, context, false);
    }

    /** UPDATE of the non-null non-primary-key columns, filtered by the querier conditions. */
    public String updateSelective(Map<String,Object> params, ProviderContext context){
        return updateByQuerier(params, context, true);
    }

    /** Shared body of updateById / updateByIdSelective; {@code selective} skips null properties. */
    private String updateByPrimaryKey(Map<String,Object> params, ProviderContext context, boolean selective){
        Object entity = params.get(ENTITY_NAME);
        TABLE_INF table = MapperHelper.getTable(context);
        // insertion-ordered so the generated SET clause is stable between calls
        Map<String,String> setValues = new LinkedHashMap<>();
        StringBuilder condition = new StringBuilder();
        table.getColumns().forEach(item -> {
            String column = item.getColumn(), property = item.getProperty();
            Object value = readValue(item.getField(), property, entity);
            if(item.getIsPK()){
                // must propagate: without a key the statement would have no WHERE clause
                requirePrimaryKey(property, value);
                params.put(property, value);
                condition.append(column).append("=").append("#{").append(property).append("}");
            }else if(value != UNREADABLE && (!selective || Objects.nonNull(value))){
                params.put(property, value);
                setValues.put(column, "#{"+property+"}");
            }
        });
        if(condition.length() == 0){
            throw new IllegalStateException("No primary key column found for table '"+table.getTableName()+"' !");
        }

        return SqlHelper.sql()
                .update(table.getTableName())
                .set(setValues)
                .where(condition.toString())
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /** Shared body of update / updateSelective; {@code selective} skips null properties. */
    private String updateByQuerier(Map<String,Object> params, ProviderContext context, boolean selective){
        Object entity = params.get(ENTITY_NAME);
        Querier<?> querier = (Querier<?>)params.get(QUERIER_NAME);
        BaseExample example = querier.getExample();
        TABLE_INF table = MapperHelper.getTable(context);
        // insertion-ordered so the generated SET clause is stable between calls
        Map<String,String> setValues = new LinkedHashMap<>();
        table.getColumns().forEach(item -> {
            if(item.getIsPK()){
                // the primary key is never modified
                return;
            }
            Object value = readValue(item.getField(), item.getProperty(), entity);
            if(value == UNREADABLE || (selective && Objects.isNull(value))){
                return;
            }
            String property = item.getProperty()+"Alias"; // avoids clashing with parameter names used by the conditions
            params.put(property, value);
            setValues.put(item.getColumn(), "#{"+property+"}");
        });

        return SqlHelper.sql()
                .update(table.getTableName())
                .set(setValues)
                .where(getConditions(example, params))
                .toStr(sql -> log(context, sql, washing(params)));
    }

    /**
     * Reads one property of the entity through its reflective field.
     * @return the value (possibly null), or {@link #UNREADABLE} when the read fails
     */
    private static Object readValue(java.lang.reflect.Field field, String property, Object entity){
        try{
            return field.get(entity);
        }catch (Exception e){
            logger.error("Failed to read property '{}' from entity", property, e);
            return UNREADABLE;
        }
    }

    /** Rejects a primary key value that is missing or could not be read. */
    private static void requirePrimaryKey(String property, Object value){
        if(value == UNREADABLE){
            throw new IllegalStateException("The primary key '"+property+"' can not be read !");
        }
        if(Objects.isNull(value)){
            throw new IllegalArgumentException("The primary key '"+property+"' can not be null !");
        }
    }

    /**
     * Joins all criteria with AND and registers their values in {@code params}.
     * Each criterion gets its own parameter name (property + sequence number) so that
     * several criteria on the same property do not overwrite each other's value.
     */
    private static String getConditions(BaseExample example, Map<String,Object> params){
        StringBuilder condition = new StringBuilder();
        int[] sequence = {0}; // mutable counter usable from inside the lambda
        example.getCriteriaList().forEach(criteria -> {
            for (Criterion cri : criteria.getAllCriteria()) {
                if(condition.length() > 0){
                    condition.append(" AND ");
                }

                String name = cri.getProperty() + "_" + (sequence[0]++);
                Object value = cri.getValue();

                if(cri.isNoValue()){
                    condition.append(cri.getCondition());
                }else if(cri.isSingleValue()){
                    params.put(name, value);
                    condition.append(cri.getCondition()).append("#{").append(name).append("}");
                }else if(cri.isListValue()){
                    params.put(name, value);
                    String inSequence = getInSequence(name, value);
                    condition.append(cri.getCondition()).append("(").append(inSequence).append(")");
                }else if(cri.isBetweenValue()){
                    String lower = name+"_1", upper = name+"_2";
                    params.put(lower, value);
                    params.put(upper, cri.getSecondValue());
                    condition.append(cri.getCondition())
                            .append("#{").append(lower).append("}")
                            .append(" AND ")
                            .append("#{").append(upper).append("}");
                }else{
                    condition.append(cri.getCondition()).append("NULL");
                }
            }
        });
        return condition.toString();
    }

    /**
     * Builds {@code column IN (...)}; an empty or missing collection yields an always-false
     * predicate instead of the invalid {@code IN ()}.
     */
    private static String getInCondition(String column, String property, Object coll){
        String inSequence = getInSequence(property, coll);
        if(inSequence.isEmpty()){
            return "1 = 0";
        }
        return column + " IN (" + inSequence + ")";
    }

    /** Builds {@code #{property[0]},#{property[1]},...} with one placeholder per element; empty string if there are none. */
    private static String getInSequence(String property, Object coll){
        if(!(coll instanceof Collection)){
            return "";
        }
        int size = ((Collection<?>) coll).size();
        StringJoiner placeholders = new StringJoiner(",");
        for (int i=0; i<size; i++) {
            placeholders.add("#{"+property+"["+i+"]}");
        }
        return placeholders.toString();
    }

    /** Returns a copy of the parameters without the wrapper entries, for logging; the live map is left untouched. */
    private static Map<String,Object> washing(Map<String,Object> params){
        Map<String,Object> printable = new HashMap<>(params);
        Arrays.asList("param1", QUERIER_NAME, ENTITY_NAME).forEach(printable::remove);
        return printable;
    }

    /** Logs the mapper method, the generated SQL and its parameters. */
    private static void log(ProviderContext context, String sql, Object params){
        String mapperMethodName = context.getMapperType().getName()+"."+context.getMapperMethod().getName();
        logger.info("\nMethod: {}\nSQL: {}\nParams: {}", mapperMethodName, sql, params==null?"{}":params);
    }
}

5.定义一个Mapper发现器:MapperDiscoverer

主要用来实现执行Insert语句时,将数据库生成的主键自动回写到实体内

package com.vz.mybatis.enhance.common.mapper.core;

import com.vz.mybatis.enhance.common.mapper.hp.MapperHelper;
import com.vz.mybatis.enhance.common.mapper.inf.COLUMN_INF;
import org.apache.ibatis.executor.keygen.Jdbc3KeyGenerator;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.Configuration;
import org.mybatis.spring.SqlSessionTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Field;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Finds every {@link BaseMapper} bean once the root application context is refreshed and
 * reconfigures its INSERT statements so the generated primary key is written back to the entity.
 *
 * @author visy.wang
 * @date 2023/4/28 13:43
 */
@Component
public class MapperDiscoverer implements ApplicationListener<ContextRefreshedEvent> {
    private static final Logger logger = LoggerFactory.getLogger(MapperDiscoverer.class);
    private static Field keyColumnsField, keyPropertiesField, keyGeneratorField;

    @Autowired
    private SqlSessionTemplate sqlSessionTemplate;

    /**
     * Patches the INSERT statements of all BaseMapper beans; events from child contexts are ignored.
     * @param event the context-refreshed event
     */
    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        ApplicationContext applicationContext = event.getApplicationContext();
        if(Objects.nonNull(applicationContext.getParent())){
            return;
        }
        @SuppressWarnings("rawtypes")
        Map<String, BaseMapper> mappers = applicationContext.getBeansOfType(BaseMapper.class);
        Map<String, List<MappedStatement>> insertMappedStatements = getInsertMappedStatements();
        mappers.forEach((name, proxyMapper) -> {
            Class<?> mapperType = resolveMapperType(proxyMapper);
            if(Objects.isNull(mapperType)){
                logger.warn("No BaseMapper interface found on bean: {}", name);
                return;
            }
            // INSERT statements registered for this mapper
            List<MappedStatement> mappedStatements = insertMappedStatements.get(mapperType.getName());
            if(CollectionUtils.isEmpty(mappedStatements)){
                return;
            }
            COLUMN_INF pkColumn = MapperHelper.getTable(mapperType).getPkColumn();
            String keyProperty = BaseSqlProvider.ENTITY_NAME+"."+pkColumn.getProperty();
            mappedStatements.forEach(statement -> {
                // change the INSERT statement configuration so the primary key is written back
                modifyMappedStatement(statement, pkColumn.getColumn(), keyProperty);
            });
        });
    }

    /**
     * Picks the mapper interface of a mapper bean: the first directly implemented interface
     * that is a BaseMapper, rather than blindly the first interface.
     * @return the mapper interface, or null when the bean class implements none
     */
    private static Class<?> resolveMapperType(Object proxyMapper){
        for (Class<?> candidate : proxyMapper.getClass().getInterfaces()) {
            if(BaseMapper.class.isAssignableFrom(candidate)){
                return candidate;
            }
        }
        return null;
    }

    /**
     * Collects the configuration of all INSERT statements registered in MyBatis.
     * @return map from mapper fully-qualified name (statement id up to the last dot) to its INSERT statements
     */
    private Map<String,List<MappedStatement>> getInsertMappedStatements(){
        Configuration configuration = sqlSessionTemplate.getConfiguration();
        Collection<MappedStatement> mappedStatements = configuration.getMappedStatements();
        return mappedStatements.stream().filter(statement -> {
            return SqlCommandType.INSERT.equals(statement.getSqlCommandType());
        }).collect(Collectors.groupingBy(statement -> {
            String statementId = statement.getId();
            return statementId.substring(0, statementId.lastIndexOf("."));
        }));
    }

    /**
     * Changes the statement configuration through reflection.
     * Equivalent to adding: @Options(useGeneratedKeys = true, keyColumn = "keyColumn", keyProperty = "keyProperty")
     * @param statement the SQL statement, corresponding to one mapper method
     * @param keyColumn name of the primary key in the database
     * @param keyProperty name of the primary key in the entity object
     */
    private static void modifyMappedStatement(MappedStatement statement, String keyColumn, String keyProperty){
        String[] keyColumns = statement.getKeyColumns();
        if(Objects.nonNull(keyColumns) && keyColumns.length>0){
            // already configured: leave it alone
            return;
        }
        String statementId = statement.getId();
        try{
            if(Objects.isNull(keyGeneratorField)){
                // resolve all three fields first and publish them together, so a failed
                // lookup never leaves the cache half-initialised
                Class<?> statementClass = statement.getClass();
                Field columnsField = statementClass.getDeclaredField("keyColumns");
                Field propertiesField = statementClass.getDeclaredField("keyProperties");
                Field generatorField = statementClass.getDeclaredField("keyGenerator");
                AccessibleObject[] accessibleObjects = {columnsField, propertiesField, generatorField};
                AccessibleObject.setAccessible(accessibleObjects, true);
                keyColumnsField = columnsField;
                keyPropertiesField = propertiesField;
                keyGeneratorField = generatorField;
            }
            keyColumnsField.set(statement, new String[]{keyColumn});
            keyPropertiesField.set(statement, new String[]{keyProperty});
            keyGeneratorField.set(statement, Jdbc3KeyGenerator.INSTANCE);
            logger.info("Mapped statement modify success, keyColumn: {}, keyProperty: {}, path: {}", keyColumn, keyProperty, statementId);
        }catch (Exception e){
            // trailing throwable keeps the stack trace in the log
            logger.warn("Mapped statement modify failure, error: {}, path: {}", e.getMessage(), statementId, e);
        }
    }
}

6.使用:

1.定义一个自己的Mapper: UserMapper ,继承BaseMapper, 可添加自定义的其他SQL;
2.定义一个实体:TSupplierUser , 和数据库表t_supplier_user的字段一一对应,一般用MyBatis Generator自动生成 ;
3.注入UserMapper并使用;

6.1 UserMapper :
package com.vz.mybatis.enhance.mapper;

import com.vz.mybatis.enhance.common.mapper.core.BaseMapper;
import com.vz.mybatis.enhance.entity.TSupplierUser;

/**
 * Application-defined mapper for {@link TSupplierUser} with a {@code Long} primary key;
 * it inherits the generic operations declared on {@link BaseMapper}.
 *
 * @author visy.wang
 * @date 2023/4/24 14:17
 */
public interface UserMapper extends BaseMapper<TSupplierUser, Long> {
	// additional custom SQL methods can be declared here
}
6.2 TSupplierUser :
package com.vz.mybatis.enhance.entity;

import lombok.Data;

import java.util.Date;

/**
 * Entity whose fields correspond one-to-one to the columns of table {@code t_supplier_user}.
 * Accessors are generated by Lombok's {@code @Data}.
 *
 * @author visy.wang
 * @date 2023/4/24 14:43
 */
@Data
public class TSupplierUser {
    private Long userId;

    private String userName;

    private String phone;

    private Integer status;

    private Long enterpriseId;

    private Date createDt;

    private String password;
}
6.3 SupplierUserController :
package com.vz.mybatis.enhance.controller;

import com.vz.mybatis.enhance.common.mapper.qr.Querier;
import com.vz.mybatis.enhance.entity.TSupplierUser;
import com.vz.mybatis.enhance.mapper.UserMapper;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;
import java.util.*;
import java.util.stream.Collectors;

/**
 * User endpoints (for testing) that exercise each generic operation of {@link UserMapper}.
 *
 * @author visy.wang
 * @date 2023/4/24 18:09
 */
@RequestMapping("/user")
@RestController
public class SupplierUserController {
    @Resource
    private UserMapper userMapper;

    /** Returns the user with the given primary key. */
    @RequestMapping("/get/{id}")
    public Map<String,Object> getById(@PathVariable("id") Long id){
        TSupplierUser user = userMapper.selectById(id);
        return resp(0, user);
    }

    /** Returns the users whose primary keys are given as a comma-separated list. */
    @RequestMapping("/getList/{ids}")
    public Map<String,Object> getByIds(@PathVariable("ids") String ids){
        List<TSupplierUser> userList = userMapper.selectByIds(Arrays.stream(ids.split(",")).map(Long::valueOf).collect(Collectors.toList()));
        return resp(0, userList);
    }

    /** Lists up to {@code limit} users whose enterprise id is greater than {@code entId}, with the matching count. */
    @RequestMapping("/list")
    public Map<String,Object> list(@RequestParam Long entId,
                                   @RequestParam(required = false, defaultValue = "20") Integer limit){
        Querier<TSupplierUser> querier = Querier.<TSupplierUser>query()
                .gt(TSupplierUser::getEnterpriseId, entId)
                .limit(limit);

        long total = userMapper.count(querier);
        List<TSupplierUser> supplierUserList = userMapper.selectList(querier);
        return resp(total, supplierUserList);
    }

    /** Returns the first user, ordered by user id ascending, whose enterprise id is greater than {@code entId}. */
    @RequestMapping("/getOne")
    public Map<String,Object> getOne(@RequestParam Long entId){
        Querier<TSupplierUser> querier = Querier.<TSupplierUser>query()
                .gt(TSupplierUser::getEnterpriseId, entId).asc(TSupplierUser::getUserId);
        return resp(0, userMapper.selectOne(querier));
    }

    /** Returns all users together with the total count. */
    @RequestMapping("/listAll")
    public Map<String,Object> listAll(){
        long all = userMapper.countAll();
        List<TSupplierUser> supplierUserList = userMapper.selectAll();
        return resp(all, supplierUserList);
    }

    /** Deletes the user with the given primary key. */
    @RequestMapping("/del/{id}")
    public Map<String,Object> deleteById(@PathVariable("id") Long id){
        int rows = userMapper.deleteById(id);
        return resp(rows, null);
    }

    /** Deletes the users whose enterprise id equals {@code entId}. */
    @RequestMapping("/delete")
    public Map<String,Object> delete(@RequestParam Long entId){
        int rows = userMapper.delete(Querier.<TSupplierUser>query().eq(TSupplierUser::getEnterpriseId, entId));
        return resp(rows, null);
    }

    /**
     * Inserts a generated test user.
     * @param type 1 uses {@code insert}, any other value uses {@code insertSelective}
     */
    @RequestMapping("/add")
    public Map<String,Object> add(@RequestParam(required = false, defaultValue = "1") Integer type){
        Querier<TSupplierUser> querier = Querier.<TSupplierUser>query()
                .desc(TSupplierUser::getUserId)
                .limit(1);

        List<TSupplierUser> supplierUserList = userMapper.selectList(querier);
        // an empty table has no "latest" user to derive the next id from: start at 1
        long newId = supplierUserList.isEmpty() ? 1L : supplierUserList.get(0).getUserId()+1;

        TSupplierUser user = new TSupplierUser();
        user.setUserName("张三"+newId);
        user.setPhone("19301293031"+newId);
        user.setPassword(UUID.randomUUID().toString().replace("-", ""));
        user.setCreateDt(new Date());
        user.setEnterpriseId(newId);

        int rows = type==1 ? userMapper.insert(user) : userMapper.insertSelective(user);

        return resp(rows, user);
    }

    /**
     * Updates password and creation date of the users whose enterprise id equals {@code entId}.
     * @param type 1 uses {@code update}, any other value uses {@code updateSelective}
     */
    @RequestMapping("/update")
    public Map<String,Object> update(@RequestParam Long entId,
                                     @RequestParam(required = false, defaultValue = "1") Integer type){
        Querier<TSupplierUser> querier = Querier.<TSupplierUser>query().eq(TSupplierUser::getEnterpriseId, entId);

        TSupplierUser user = new TSupplierUser();
        user.setPassword(UUID.randomUUID().toString().replace("-", ""));
        user.setCreateDt(new Date());

        int rows = type==1 ? userMapper.update(user, querier) : userMapper.updateSelective(user, querier);

        return resp(rows, null);
    }

    /**
     * Updates password and creation date of the user with the given primary key.
     * @param type 1 uses {@code updateById}, any other value uses {@code updateByIdSelective}
     */
    @RequestMapping("/upd/{id}")
    public Map<String,Object> updateById(@PathVariable("id") Long id,
                                         @RequestParam(required = false, defaultValue = "1") Integer type){
        TSupplierUser user = new TSupplierUser();
        user.setUserId(id);
        user.setPassword(UUID.randomUUID().toString().replace("-", ""));
        user.setCreateDt(new Date());

        int rows = type==1 ? userMapper.updateById(user) : userMapper.updateByIdSelective(user);

        return resp(rows, null);
    }

    /** Wraps a count and a payload into the response map under the keys "rows" and "data". */
    private Map<String,Object> resp(long total, Object data){
        Map<String,Object> mp = new HashMap<>();
        mp.put("rows", total);
        mp.put("data", data);
        return mp;
    }

}

7.项目源码

MyBatis Enhance

8.其他说明

  • 定义实体应该和数据库表一一对应,一般用MyBatis Generator等工具自动生成;
  • 主键不需要指定,通过查询数据库(MySQL)的表定义元数据自动识别主键字段;
  • 目前仅支持单主键的表,联合主键暂不支持(但其实可以扩展现有代码来支持);
  • 查询条件构建器Querier暂不支持OR条件及嵌套的查询,仅支持AND串联的查询;
  • Querier改造自MyBatis的Example查询,Example支持OR及其嵌套,所以理论上Querier也能改造成支持OR的查询器;
  • 当数据表发生变化时,只需更新对应的Java实体即可,其他代码保持不变。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值