Spring Boot + MyBatis: extracting a non-invasive shared DAO with annotations (@SelectProvider, @UpdateProvider, @InsertProvider, @DeleteProvider)


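Preface

Developing with MyBatis through XML mappers is tedious, and every added or removed column means editing the XML again. With annotations, each new bean still needs its own hand-written CRUD SQL, which is tedious too. So I decided to extract a shared DAO. Because I find XML cumbersome, I developed and tested it with annotations only.
The shared DAO generates its SQL without relying on any additional utility library.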

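Writing the bean

The @Table, @Id and @Column annotations mark a non-default table name, the primary key, and non-default column names, which makes the DAO applicable to more cases.
By default, table and column names are derived from the class and field names by converting camel case to lowercase words joined with underscores ("_"). Names that do not follow this rule can be set through the name attribute of @Table or @Column.
Here is an example: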
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import javax.persistence.Column;
import javax.persistence.Id;
import javax.persistence.Table;
import java.time.LocalDateTime;

/**
 * @author johny
 * @date 2022/8/30
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@Table(name = "t_test_table")
public class TestBean {

    @Id
    private Long id;

    private String name;

    private String address;

    @Column(name = "create_time")
    private LocalDateTime createTime;
    @Column(name = "update_time")
    private LocalDateTime updateTime;

}
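
The default naming rule described above can be sketched on its own as follows. This is a minimal illustration, not the provider code from this post: the class name NamingSketch and the method toSnakeCase are hypothetical, and the sketch deliberately skips the "_" before a capital in first position so that a class name does not end up with a leading underscore.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

final class NamingSketch {
    private static final Pattern UPPER = Pattern.compile("[A-Z]");

    /** Converts a camelCase or PascalCase identifier to lower_snake_case. */
    static String toSnakeCase(String name) {
        Matcher matcher = UPPER.matcher(name);
        StringBuffer sb = new StringBuffer();
        while (matcher.find()) {
            // Prefix every capital with "_" except one at the very start (class names).
            String prefix = matcher.start() == 0 ? "" : "_";
            matcher.appendReplacement(sb, prefix + matcher.group().toLowerCase());
        }
        matcher.appendTail(sb);
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(toSnakeCase("createTime")); // create_time
        System.out.println(toSnakeCase("TestBean"));   // test_bean
    }
}
```

With this rule, the two @Column annotations on TestBean are optional, because createTime and updateTime already map to create_time and update_time. The @Table annotation is still needed, since the default for TestBean would be test_bean rather than t_test_table.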

baseDao

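1. In the insert methods, updateFields is the set of columns to update when the insert hits a unique-key conflict.
2. The query methods take a varargs parameter, fields, that names the columns to return, which makes them more flexible.
3. For list queries, paging uses Pageable from Spring Data, and IN conditions are passed in through a map.
Pageable carries the paging parameters page and size as well as a Sort, and a Sort can hold several orders. See the service code further down for usage.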
import com.johny.test.util.MybatisUtil;
import org.apache.ibatis.annotations.*;
import org.springframework.data.domain.Pageable;

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * T is the entity type, I is the primary-key type
 *
 * @author johny
 * @date 2022/9/2
 */
public interface MybatisBaseDao<T, I> {
    /**
     * Insert a record.
     *
     * @param bean         the object to save
     * @param updateFields columns to update on a unique-key conflict
     * @return
     * @throws Exception
     */
    @InsertProvider(type = MybatisUtil.class, method = "insert")
    boolean insert(@Param("bean") T bean,@Param("updateFields") Set<String> updateFields) throws Exception;

    /**
     * Batch insert.
     *
     * @param list         the objects to save
     * @param updateFields columns to update on a unique-key conflict
     * @return
     * @throws Exception
     */
    @InsertProvider(type = MybatisUtil.class, method = "batchInsert")
    void batchInsert(@Param("list") List<T> list, @Param("updateFields") Set<String> updateFields) throws Exception;

    /**
     * Update by primary key.
     *
     * @param bean the object to update
     * @return
     * @throws Exception
     */
    @UpdateProvider(type = MybatisUtil.class, method = "updateByPrimaryKey")
    boolean updateByPrimaryKey(@Param("bean") T bean) throws Exception;

    /**
     * Delete by primary key.
     *
     * @param primaryKey the primary key
     * @return
     * @throws Exception
     */
    @DeleteProvider(type = MybatisUtil.class, method = "deleteByPrimaryKey")
    boolean deleteByPrimaryKey(I primaryKey) throws Exception;

    /**
     * Delete by bean.
     *
     * @param bean the delete conditions (non-null fields are matched with "=")
     * @return
     * @throws Exception
     */
    @DeleteProvider(type = MybatisUtil.class, method = "deleteByBean")
    boolean deleteByBean(@Param("bean") T bean) throws Exception;

    /**
     * Query one object by primary key.
     *
     * @param primaryKey the primary key
     * @param fields     columns to return
     * @return
     * @throws Exception
     */
    @SelectProvider(type = MybatisUtil.class, method = "queryObjectByPrimaryKey")
    T queryObjectByPrimaryKey(@Param("primaryKey") I primaryKey, @Param("fields") String... fields) throws Exception;

    /**
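     * Query a list by bean.
     *
     * @param bean     equality conditions (non-null fields are matched with "=")
     * @param inMap    IN conditions, keyed by column name
     * @param pageable paging and sorting parameters
     * @param fields   columns to return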
     * @return
     * @throws Exception
     */
    @SelectProvider(type = MybatisUtil.class, method = "queryListByBean")
    List<T> queryListByBean(@Param("bean") T bean, @Param("inMap") Map<String, List<String>> inMap,
                            @Param("pageable") Pageable pageable, @Param("fields") String... fields) throws Exception;

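    /**
     * Query a list by ordinary WHERE conditions.
     *
     * @param conditions conditions keyed by column name; each value holds "operate" and "value"
     * @param inMap      IN conditions, keyed by column name
     * @param pageable   paging and sorting parameters
     * @param fields     columns to return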
     * @return
     * @throws Exception
     */
    @SelectProvider(type = MybatisUtil.class, method = "queryList")
    List<T> queryList(@Param("conditions") Map<String,Map<String, Object>> conditions, @Param("inMap") Map<String, List<String>> inMap,
                      @Param("pageable") Pageable pageable, @Param("fields") String... fields) throws Exception;
}
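
To make the role of updateFields concrete, here is a rough sketch of the statement shape an insert with conflict handling ends up with. It is a simplified stand-alone illustration under stated assumptions: UpsertSketch and buildInsert are hypothetical names, column names are assumed to equal property names, only a single row is handled, and the ON DUPLICATE KEY UPDATE syntax is MySQL-specific.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;

final class UpsertSketch {
    /** Builds a MySQL-style insert for one row; updateFields holds column names. */
    static String buildInsert(String table, List<String> columns, Set<String> updateFields) {
        List<String> placeholders = new ArrayList<>();
        List<String> updates = new ArrayList<>();
        for (String column : columns) {
            // MyBatis placeholder, bound as a parameter at execution time.
            placeholders.add("#{bean." + column + "}");
            if (updateFields != null && updateFields.contains(column)) {
                updates.add(column + "=VALUES(" + column + ")");
            }
        }
        String sql = "INSERT INTO " + table + " (" + String.join(", ", columns) + ") VALUES ("
                + String.join(", ", placeholders) + ")";
        if (!updates.isEmpty()) {
            // Appended only when at least one inserted column is listed in updateFields.
            sql += " ON DUPLICATE KEY UPDATE " + String.join(",", updates);
        }
        return sql;
    }

    public static void main(String[] args) {
        System.out.println(buildInsert("t_test_table",
                Arrays.asList("name", "address"), Collections.singleton("address")));
    }
}
```

Passing null or an empty set for updateFields yields a plain INSERT, so a unique-key conflict then surfaces as an error instead of an update.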

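The SQL provider

1. The SQL is assembled with the SQL class that ships with MyBatis (org.apache.ibatis.jdbc.SQL).
2. Table and column names are obtained from the class by reflection, split on camel-case boundaries and joined with "_". Special cases can be set through the name attribute of @Table and @Column.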
import org.apache.ibatis.builder.annotation.ProviderContext;
import org.apache.ibatis.jdbc.SQL;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.util.CollectionUtils;

import javax.persistence.Column;
import javax.persistence.Id;
import javax.persistence.Table;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;


/**
 * MyBatis utility class (the SQL provider)
 *
 * @author johny
 * @date 2022/9/2
 */
public class MybatisUtil{
    /** Matches the uppercase letters that mark camel-case word boundaries. */
    private final static Pattern HUMP_PATTERN = Pattern.compile("[A-Z]");

    /**
     * Insert
     *
     * @param map     the method parameters, keyed by their @Param names
     * @param context injected automatically by MyBatis
     * @return
     * @throws Exception
     */
    public String insert(Map<String, Object> map, ProviderContext context) throws Exception {
        Object o = map.get("bean");
        List<Object> list = new ArrayList<>();
        list.add(o);
        map.put("list", list);
        return batchInsert(map, context);
    }

    /**
     * Batch insert
     *
     * @param map
     * @return
     * @throws Exception
     */
    public String batchInsert(Map<String, Object> map, ProviderContext context) throws Exception {
        if (map.get("list") == null) {
            return null;
        }
        List<Object> list = (List<Object>) map.get("list");
        if (CollectionUtils.isEmpty(list)) {
            return null;
        }
        Set<String> updateFields = null;
        if (map.get("updateFields") != null) {
            updateFields = (Set<String>) map.get("updateFields");
        }
        StringBuilder updateFieldSql = new StringBuilder();
        Class beanClass = getClass(context);
        SQL sql = new SQL();
        sql.INSERT_INTO(getTableName(beanClass));
        fillInsert(list, sql, updateFields, updateFieldSql);
        String toString = sql.toString();
        if (updateFieldSql.length() > 0) {
            updateFieldSql.deleteCharAt(updateFieldSql.length() - 1);
            toString += " ON DUPLICATE KEY UPDATE " + updateFieldSql;
        }
        return toString;
    }

    /**
     * Update by primary key
     *
     * @param map
     * @return
     * @throws Exception
     */
    public String updateByPrimaryKey(Map<String, Object> map, ProviderContext context) throws Exception {
        Object bean = map.get("bean");
        SQL sql = new SQL();
        sql.UPDATE(getTableName(bean));
        fillFields(bean, sql, 2);
        return sql.toString();
    }

    /**
     * Delete by primary key
     *
     * @param context
     * @return
     * @throws Exception
     */
    public String deleteByPrimaryKey(ProviderContext context) throws Exception {
        Class beanClass = getClass(context);
        SQL sql = new SQL();
        sql.DELETE_FROM(getTableName(beanClass));
        sql.WHERE(getPrimaryKey(beanClass) + "=#{primaryKey}");
        return sql.toString();
    }

    /**
     * Delete by bean
     *
     * @param context
     * @return
     * @throws Exception
     */
    public String deleteByBean(Map<String, Object> map, ProviderContext context) throws Exception {
        Class beanClass = getClass(context);
        SQL sql = new SQL();
        sql.DELETE_FROM(getTableName(beanClass));
        fillFields(map.get("bean"), sql, 4);
        return sql.toString();
    }

    /**
     * Query by primary key
     *
     * @param map
     * @return
     * @throws Exception
     */
    public String queryObjectByPrimaryKey(Map<String, Object> map, ProviderContext context) throws Exception {
        Class beanClass = getClass(context);
        SQL sql = new SQL();
        fillSelectFields(beanClass, sql, map.get("fields") == null ? null : Stream.of((String[]) map.get("fields")).collect(Collectors.toSet()));
        sql.FROM(getTableName(beanClass));
        sql.WHERE(getPrimaryKey(beanClass) + "=#{primaryKey}");
        return sql.toString();
    }

    /**
     * Query a list by bean
     *
     * @param map
     * @return
     * @throws Exception
     */
    public String queryListByBean(Map<String, Object> map, ProviderContext context) throws Exception {
        Class beanClass = getClass(context);
        SQL sql = new SQL();
        fillSelectFields(beanClass, sql, map.get("fields") == null ? null : Stream.of((String[]) map.get("fields")).collect(Collectors.toSet()));
        sql.FROM(getTableName(beanClass));
        fillFields(map.get("bean"), sql, 3);
        if (map.get("inMap") != null) {
            fillInConditions(map, sql);
        }
        if (map.get("pageable") != null) {
            fillPage(map, sql);
        }
        return sql.toString();
    }
    
    /**
     * Query a list by ordinary WHERE conditions
     *
     * @param map
     * @return
     * @throws Exception
     */
    public String queryList(Map<String, Object> map, ProviderContext context) throws Exception {
        Class beanClass = getClass(context);
        SQL sql = new SQL();
        fillSelectFields(beanClass, sql, map.get("fields") == null ? null : Stream.of((String[]) map.get("fields")).collect(Collectors.toSet()));
        sql.FROM(getTableName(beanClass));
        if (map.get("conditions") != null) {
            fillConditions(map, sql);
        }
        if (map.get("inMap") != null) {
            fillInConditions(map, sql);
        }
        if (map.get("pageable") != null) {
            fillPage(map, sql);
        }
        return sql.toString();
    }

    /**
     * Fill the sorting and paging clauses
     *
     * @param map
     * @param sql
     */
    private void fillPage(Map<String, Object> map, SQL sql) {
        Pageable pageable = (Pageable) map.get("pageable");
        if (pageable.getSort().isSorted()) {
            for (Sort.Order order : pageable.getSort().toList()) {
                sql.ORDER_BY(order.getProperty() + " " + order.getDirection());
            }
        }
        sql.OFFSET(pageable.getOffset()).LIMIT(pageable.getPageSize());
    }

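    /**
     * Fill ordinary WHERE conditions (column, operator, value).
     * The column name and operator are concatenated into the SQL as-is, so they must never come from user input.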
     *
     * @param map
     * @param sql
     */
    private void fillConditions(Map<String, Object> map, SQL sql) {
        Map<String, Map<String,Object>> conditions = (Map<String, Map<String,Object>>) map.get("conditions");
        for (Map.Entry<String, Map<String,Object>> entry : conditions.entrySet()) {
            Map<String,Object> value = entry.getValue();
            String operate = value.get("operate").toString();
            sql.WHERE(entry.getKey() + operate +" #{conditions." + entry.getKey() + ".value}");
        }
    }

    /**
     * Fill IN conditions
     *
     * @param map
     * @param sql
     */
    private void fillInConditions(Map<String, Object> map, SQL sql) {
        Map<String, List<String>> inMap = (Map<String, List<String>>) map.get("inMap");
        for (Map.Entry<String, List<String>> entry : inMap.entrySet()) {
            StringBuilder builder = new StringBuilder();
            // Bind each value as a #{...} parameter instead of concatenating quoted literals
            List<String> value = entry.getValue();
            for (int i = 0; i < value.size(); i++) {
                builder.append("#{inMap." + entry.getKey() + "[" + i + "]}").append(",");
            }
            builder.deleteCharAt(builder.length() - 1);
            sql.WHERE(entry.getKey() + " in (" + builder + ")");
        }
    }

    /**
     * Get the primary-key column name
     *
     * @param c
     * @return
     */
    private String getPrimaryKey(Class c) {
        Field[] fields = c.getDeclaredFields();
        for (Field field : fields) {
            Id idColumn = field.getAnnotation(Id.class);
            if (idColumn != null) {
                // Return the column name (honouring @Column), not the raw field name
                return getColumnName(field);
            }
        }
        return null;
    }

    /**
     * Read the bean's fields and fill them into the SQL
     *
     * @param o       the bean
     * @param sql     the SQL being built
     * @param sqlType 1 insert, 2 update, 3 query, 4 delete
     * @return
     */
    private void fillFields(Object o, SQL sql, int sqlType) throws Exception {
        if (o != null) {
            Field[] fields = o.getClass().getDeclaredFields();
            for (Field field : fields) {
                field.setAccessible(true);
                String name = field.getName();
                String columnName = getColumnName(field);
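                Id idColumn = field.getAnnotation(Id.class);
                Object o1 = field.get(o);
                if (idColumn != null) {
                    // An update always filters on the primary key; a query or delete
                    // filters on it only when it is set, so a null id does not yield "id = NULL"
                    if (sqlType == 2 || ((sqlType == 3 || sqlType == 4) && o1 != null)) {
                        sql.WHERE(columnName + "=#{bean." + name + "}");
                    }
                    continue;
                }
                if (o1 != null) {
                    if (sqlType == 2) {
                        // Update: set every non-null field
                        sql.SET(columnName + "=#{bean." + name + "}");
                    } else if (sqlType == 3 || sqlType == 4) {
                        // Query or delete: filter on every non-null field
                        sql.WHERE(columnName + "=#{bean." + name + "}");
                    }
                }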
            }
        }
    }

    /**
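     * Fill the insert SQL
     *
     * @param list           the objects to insert
     * @param sql            the SQL being built
     * @param updateFields   columns to update on a conflict
     * @param updateFieldSql receives the generated update clause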
     * @throws Exception
     */
    private void fillInsert(List<Object> list, SQL sql, Set<String> updateFields, StringBuilder updateFieldSql) throws Exception {
        // Placeholder template for one row of VALUES
        StringBuilder pattern = new StringBuilder();
        // Columns are taken from the non-null fields of the first object
        Object o = list.get(0);
        Field[] fields = o.getClass().getDeclaredFields();
        for (Field field : fields) {
            Id idColumn = field.getAnnotation(Id.class);
            if (idColumn != null) {
                continue;
            }
            field.setAccessible(true);
            String name = field.getName();
            Object o1 = field.get(o);
            if (o1 != null) {
                String columnName = getColumnName(field);
                sql.INTO_COLUMNS(columnName);
                pattern.append("#{list[%d].").append(name).append("},");
                if (!CollectionUtils.isEmpty(updateFields) && updateFields.contains(columnName)) {
                    updateFieldSql.append(columnName + "=VALUES(" + columnName + "),");
                }
            }
        }
        if (pattern.length() > 0) {
            pattern.deleteCharAt(pattern.length() - 1);
            for (int i = 0; i < list.size(); i++) {
                if (i > 0) {
                    sql.ADD_ROW();
                }
                String format = pattern.toString().replaceAll("%d", i + "");
                sql.INTO_VALUES(format);
            }
        }
    }

    /**
     * Get the column name of a field
     *
     * @param field
     * @return
     */
    private String getColumnName(Field field) {
        String columnName;
        // Use the explicit @Column name when one is given
        Column annotation = field.getDeclaredAnnotation(Column.class);
        if (annotation != null && !annotation.name().isEmpty()) {
            columnName = annotation.name();
        } else {
            columnName = humpToLine2(field.getName());
        }
        return columnName;
    }

    /**
     * Fill the columns to select
     *
     * @param o
     * @return
     */
    private void fillSelectFields(Class o, SQL sql, Set<String> selectFields) {
        Field[] fields = o.getDeclaredFields();
        for (Field field : fields) {
            String columnName = getColumnName(field);
            if (!CollectionUtils.isEmpty(selectFields)) {
                if (selectFields.contains(columnName)) {
                    sql.SELECT(columnName + " AS " + field.getName());
                }
            } else {
                sql.SELECT(columnName + " AS " + field.getName());
            }
        }
        // Fail fast when none of the requested fields matches a column
        if (sql.toString().trim().isEmpty()) {
            throw new RuntimeException("No valid field to select");
        }
    }

    private Class getClass(ProviderContext context){
        Class mClass = context.getMapperType();
        return (Class) ((ParameterizedType) (mClass.getGenericInterfaces()[0])).getActualTypeArguments()[0];
    }

    /**
     * Get the table name from a class
     *
     * @param c
     * @return
     */
    private String getTableName(Class c){
        Annotation[] annotations = c.getDeclaredAnnotations();
        for (Annotation annotation : annotations) {
            if (annotation instanceof Table){
                return ((Table) annotation).name();
            }
        }
        String className = c.getSimpleName();
        return humpToLine2(className);
    }

    /**
     * Get the table name from an object
     *
     * @param o
     * @return
     */
    private String getTableName(Object o){
        Table annotation = o.getClass().getAnnotation(Table.class);
        if (annotation != null){
            return annotation.name();
        }
        String className = o.getClass().getSimpleName();
        return humpToLine2(className);
    }

    /**
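     * Convert a camel-case name to a lowercase name joined with "_"
     *
     * @param str
     * @return
     */
    private static String humpToLine2(String str) {
        Matcher matcher = HUMP_PATTERN.matcher(str);
        StringBuffer sb = new StringBuffer();
        while (matcher.find()) {
            // No leading "_" for a capital at position 0, so TestBean becomes test_bean
            String prefix = matcher.start() == 0 ? "" : "_";
            matcher.appendReplacement(sb, prefix + matcher.group(0).toLowerCase());
        }
        matcher.appendTail(sb);
        return sb.toString();
    }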

}
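
The provider finds the entity class by reading the generic type argument of the mapper interface. The following stand-alone sketch shows that reflection step in isolation; the BaseDao, User and UserDao types and the resolveEntityType method are hypothetical stand-ins for MybatisBaseDao, TestBean, TestDao and the provider's getClass helper. Unlike the one-line cast in the provider, it checks each step and fails with a clear message when the mapper does not directly extend a parameterized interface.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;

final class GenericTypeSketch {
    interface BaseDao<T, I> { }

    static class User { }

    interface UserDao extends BaseDao<User, Long> { }

    /** Returns the first type argument of the first parameterized super-interface. */
    static Class<?> resolveEntityType(Class<?> mapperType) {
        for (Type type : mapperType.getGenericInterfaces()) {
            if (type instanceof ParameterizedType) {
                Type arg = ((ParameterizedType) type).getActualTypeArguments()[0];
                if (arg instanceof Class) {
                    return (Class<?>) arg;
                }
            }
        }
        throw new IllegalArgumentException(
                mapperType.getName() + " does not directly extend a parameterized interface");
    }

    public static void main(String[] args) {
        System.out.println(resolveEntityType(UserDao.class).getSimpleName()); // User
    }
}
```

This only looks at the interfaces the mapper declares directly, so a mapper that reaches the base DAO through an intermediate interface would need a recursive lookup.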

Test DAO

import com.johny.test.bean.TestBean;
import com.johny.test.dao.MybatisBaseDao;
import com.johny.test.util.MybatisUtil;
import org.apache.ibatis.annotations.*;

import java.util.List;
import java.util.Map;

/**
 * Test DAO
 *
 * @author johny
 * @date 2022/8/30
 */
public interface TestDao extends MybatisBaseDao<TestBean,Long> {
}

Test service

The following sample code shows how to use Pageable, for reference.
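import com.alibaba.fastjson.JSON; // assumption: the JSON class used below is fastjson
import com.johny.test.bean.TestBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// The TestDao import is omitted because its package is not shown in this post.

/**
 * @author johny
 * @date 2022/9/3
 */
@Service
@Lazy
public class UpdateUnionIdService {
    @Autowired
    TestDao testDao;

    /**
     * Test query
     */
    public void query() throws Exception {
        // Query condition; pass it as the first argument of queryListByBean to filter by it
        TestBean testBean = new TestBean();
        testBean.setId(16021L);
        // Sort conditions
        Sort.Order orderId = new Sort.Order(Sort.Direction.ASC, "id");
        Sort.Order orderName = new Sort.Order(Sort.Direction.DESC, "name");
        Sort sort = Sort.by(orderId, orderName);
        // Paging
        Pageable page = PageRequest.of(0, 10, sort);
        // IN conditions
        Map<String, List<String>> inMap = new HashMap<>();
        List<String> strings = new ArrayList<>();
        strings.add("测试");
        strings.add("4");
        inMap.put("name", strings);
        // Query; "create_time" is the column to return
        List<TestBean> beans = testDao.queryListByBean(null, inMap, page, "create_time");
        if (!CollectionUtils.isEmpty(beans)) {
            System.out.println(JSON.toJSONString(beans));
        }

        // Ordinary WHERE conditions. The map is keyed by column name, so each column
        // carries one condition; a second put("id", ...) would overwrite the first.
        Map<String, Map<String, Object>> conditions = new HashMap<>();
        Map<String, Object> condition = new HashMap<>();
        condition.put("operate", ">");
        condition.put("value", 0);
        conditions.put("id", condition);
        List<TestBean> filtered = testDao.queryList(conditions, null, page, "create_time");
    }
}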
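
As a rough picture of what the Pageable in the service turns into, the sketch below renders the sorting and paging clauses by hand. It is an illustration only: PageClauseSketch and pageClause are hypothetical names, the page number is zero-based as in PageRequest.of, and the LIMIT ... OFFSET ... form assumes a database such as MySQL that accepts it.

```java
final class PageClauseSketch {
    /** Renders ORDER BY / LIMIT / OFFSET for a zero-based page number. */
    static String pageClause(int page, int size, String... orderBy) {
        StringBuilder sb = new StringBuilder();
        if (orderBy.length > 0) {
            // Several orders are applied left to right, like the orders in a Sort.
            sb.append(" ORDER BY ").append(String.join(", ", orderBy));
        }
        // The offset is the number of rows skipped before the requested page.
        long offset = (long) page * size;
        sb.append(" LIMIT ").append(size).append(" OFFSET ").append(offset);
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(pageClause(0, 10, "id ASC", "name DESC"));
        System.out.println(pageClause(2, 10));
    }
}
```

The sort properties end up in the SQL text rather than in bound parameters, so they should be column names chosen by the application, not values taken from a request.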

Finally

It is best to set MyBatis's map-underscore-to-camel-case option to true.
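
A minimal configuration fragment for this option, assuming the project uses mybatis-spring-boot-starter and an application.properties file (the post does not show its configuration):

```properties
mybatis.configuration.map-underscore-to-camel-case=true
```

The generated SELECT statements already alias every column to its field name, so the option mainly matters for hand-written @Select queries that return underscore-named columns.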

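These are the shared methods I extracted. I did not write anything for more complex queries: once condition and update helper classes are added, the result is less convenient than writing the SQL directly with @Select when a complex query comes up, so I stopped extending it.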


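Update, 27 January 2024
1. An invalid field now throws an exception.
2. Added a method that supports ordinary WHERE conditions.
I hope you will follow the examples to write more extension methods that suit your own needs.

Criticism and corrections are welcome.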