前言
使用 MyBatis 开发时,如果采用 XML 方式,写法比较繁琐,增删字段时要改动多处;如果采用注解方式,每新建一个 bean 都要重新手写一遍 CRUD 的 SQL,同样麻烦。所以想到抽取一个公用的 DAO。由于之前写 XML 比较麻烦,这里直接基于注解方式进行开发和测试。
目标是在不引用其他工具类的前提下,由公用 DAO 自动生成 SQL。
写bean
这里使用 @Table、@Id、@Column 注解,分别标记特殊表名、主键和特殊列名,以提高通用性。
表名和列名默认按照"驼峰命名转小写、以下划线“_”连接"的规则解析(例如字段 createTime 对应列 create_time);有特殊命名的,可以在 @Table 和 @Column 注解的 name 属性上指定。
下面是例子:
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import javax.persistence.Column;
import javax.persistence.Id;
import javax.persistence.Table;
import java.time.LocalDateTime;
/**
 * Example bean for the generic MyBatis DAO, mapped to table t_test_table via @Table.
 * Lombok generates getters/setters, equals/hashCode/toString, no-arg and all-args
 * constructors, and a builder.
 *
 * @author johny
 * @date 2022/8/30
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
@Table(name = "t_test_table")
public class TestBean {
/** Primary key: the provider locates the key column through the @Id annotation. */
@Id
private Long id;
/** No @Column: the column name is derived from the field name. */
private String name;
/** No @Column: the column name is derived from the field name. */
private String address;
/** Mapped explicitly to column create_time. */
@Column(name = "create_time")
private LocalDateTime createTime;
/** Mapped explicitly to column update_time. */
@Column(name = "update_time")
private LocalDateTime updateTime;
}
baseDao
1. 插入方法中的 updateFields 用来指定:插入时如果发生唯一键冲突(生成 MySQL 的 ON DUPLICATE KEY UPDATE 子句),需要更新的字段集合;
2.查询时,引入了可变参数fields来指定要查询展示的列名,可以更灵活
3. 查询列表时,分页使用的是 Spring Data 的 Pageable;in 查询条件通过 Map 传入(key 为列名,value 为取值列表)。
Pageable 支持分页参数 page、size 以及排序参数 Sort,Sort 中可以添加多个 Order,具体用法请看后面 service 中的代码。
import com.johny.test.util.MybatisUtil;
import org.apache.ibatis.annotations.*;
import org.springframework.data.domain.Pageable;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Generic CRUD mapper. T is the bean type and I is the primary-key type.
 * Concrete mappers extend this interface with concrete type arguments; the SQL for every
 * method is generated at runtime by the MybatisUtil provider method named in its annotation.
 *
 * @author johny
 * @date 2022/9/2
 */
public interface MybatisBaseDao<T, I> {
/**
 * Inserts one bean (the provider wraps it in a one-element list and reuses batchInsert).
 *
 * @param bean the object to save
 * @param updateFields fields to update when a unique-key conflict occurs (rendered as
 *                     ON DUPLICATE KEY UPDATE); may be null
 * @return true when the statement affected at least one row (MyBatis maps the update count to boolean)
 * @throws Exception any failure while building or executing the statement
 */
@InsertProvider(type = MybatisUtil.class, method = "insert")
boolean insert(@Param("bean") T bean,@Param("updateFields") Set<String> updateFields) throws Exception;
/**
 * Inserts several beans with one multi-row INSERT statement.
 *
 * @param list the objects to save
 * @param updateFields fields to update when a unique-key conflict occurs (rendered as
 *                     ON DUPLICATE KEY UPDATE); may be null
 * @throws Exception any failure while building or executing the statement
 */
@InsertProvider(type = MybatisUtil.class, method = "batchInsert")
void batchInsert(@Param("list") List<T> list, @Param("updateFields") Set<String> updateFields) throws Exception;
/**
 * Updates the row identified by the bean's @Id field, setting the bean's non-null fields.
 *
 * @param bean the object carrying the key and the new values
 * @return true when at least one row was updated
 * @throws Exception any failure while building or executing the statement
 */
@UpdateProvider(type = MybatisUtil.class, method = "updateByPrimaryKey")
boolean updateByPrimaryKey(@Param("bean") T bean) throws Exception;
/**
 * Deletes the row with the given primary key.
 * NOTE(review): unlike the other methods this parameter has no @Param; the generated SQL
 * references #{primaryKey}, relying on MyBatis binding a lone simple-typed argument under
 * any name. Consider adding @Param("primaryKey") for consistency.
 *
 * @param primaryKey the primary-key value
 * @return true when at least one row was deleted
 * @throws Exception any failure while building or executing the statement
 */
@DeleteProvider(type = MybatisUtil.class, method = "deleteByPrimaryKey")
boolean deleteByPrimaryKey(I primaryKey) throws Exception;
/**
 * Deletes the rows matching the bean, whose fields are turned into equality conditions.
 *
 * @param bean the delete conditions
 * @return true when at least one row was deleted
 * @throws Exception any failure while building or executing the statement
 */
@DeleteProvider(type = MybatisUtil.class, method = "deleteByBean")
boolean deleteByBean(@Param("bean") T bean) throws Exception;
/**
 * Loads one bean by primary key.
 *
 * @param primaryKey the primary-key value
 * @param fields column names to select; when none are given, all mapped columns are selected
 * @return the matching bean, or null when no row matches
 * @throws Exception any failure while building or executing the statement
 */
@SelectProvider(type = MybatisUtil.class, method = "queryObjectByPrimaryKey")
T queryObjectByPrimaryKey(@Param("primaryKey") I primaryKey, @Param("fields") String... fields) throws Exception;
/**
 * Queries a list using equality conditions taken from a bean.
 *
 * @param bean bean whose fields become "column = value" conditions; may be null
 * @param inMap column name to list of values, rendered as "column in (...)"; may be null
 * @param pageable paging (LIMIT/OFFSET) and ORDER BY; may be null
 * @param fields column names to select; when none are given, all mapped columns are selected
 * @return the matching beans
 * @throws Exception any failure while building or executing the statement
 */
@SelectProvider(type = MybatisUtil.class, method = "queryListByBean")
List<T> queryListByBean(@Param("bean") T bean, @Param("inMap") Map<String, List<String>> inMap,
@Param("pageable") Pageable pageable, @Param("fields") String... fields) throws Exception;
/**
 * Queries a list using operator conditions.
 *
 * @param conditions column name to {"operate": comparison operator, "value": bound value};
 *                   at most one condition per column because the map is keyed by column; may be null
 * @param inMap column name to list of values, rendered as "column in (...)"; may be null
 * @param pageable paging (LIMIT/OFFSET) and ORDER BY; may be null
 * @param fields column names to select; when none are given, all mapped columns are selected
 * @return the matching beans
 * @throws Exception any failure while building or executing the statement
 */
@SelectProvider(type = MybatisUtil.class, method = "queryList")
List<T> queryList(@Param("conditions") Map<String,Map<String, Object>> conditions, @Param("inMap") Map<String, List<String>> inMap,
@Param("pageable") Pageable pageable, @Param("fields") String... fields) throws Exception;
}
提供Sql的Provider
1. 使用 MyBatis(ibatis)自带的 SQL 对象来完成 SQL 拼接;
2. 表名和列名都是通过反射获取类名、字段名,再按驼峰命名拆分并用“_”连接得到的;有特殊命名的,可以在 @Table 和 @Column 注解的 name 属性上指定。
import org.apache.ibatis.builder.annotation.ProviderContext;
import org.apache.ibatis.jdbc.SQL;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.util.CollectionUtils;
import javax.persistence.Column;
import javax.persistence.Id;
import javax.persistence.Table;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * MyBatis SQL provider that generates CRUD statements for {@code MybatisBaseDao} mappers.
 *
 * <p>The bean class is resolved from the mapper's generic type arguments. Table and column names
 * come from a non-empty {@code @Table(name)} / {@code @Column(name)} or, by default, from the
 * camelCase class/field name converted to lower-case snake_case ({@code TestBean -> test_bean},
 * {@code createTime -> create_time}). Static and synthetic fields are not mapped. Fields annotated
 * with {@code @Id} form the primary key.
 *
 * <p>Values are always bound as {@code #{...}} parameters. Identifiers and operators that have to
 * be spliced into the SQL text (keys of {@code conditions}/{@code inMap}, sort properties,
 * comparison operators) are validated first, because they may come from request parameters
 * (for example a {@code Pageable} resolved from the query string).
 *
 * @author johny
 * @date 2022/9/2
 */
public class MybatisUtil {

    /** Matches a single upper-case ASCII letter, i.e. a camelCase word boundary. */
    private static final Pattern HUMP_PATTERN = Pattern.compile("[A-Z]");

    /**
     * A bare identifier: safe to splice into SQL text and to use as a key inside an OGNL
     * placeholder such as {@code #{conditions.<key>.value}}.
     */
    private static final Pattern IDENTIFIER_PATTERN = Pattern.compile("[A-Za-z_][A-Za-z0-9_]*");

    /** Comparison operators accepted in {@code conditions} (matched after upper-casing). */
    private static final Set<String> ALLOWED_OPERATORS = Stream
            .of("=", "<>", "!=", ">", ">=", "<", "<=", "LIKE", "NOT LIKE")
            .collect(Collectors.toSet());

    /**
     * Builds an INSERT for a single bean by delegating to {@link #batchInsert}.
     *
     * @param map     parameter map holding {@code bean} and {@code updateFields}; a one-element
     *                {@code list} entry is added to it
     * @param context provider context injected by MyBatis
     * @return the INSERT statement
     * @throws Exception if the bean is null or the statement cannot be built
     */
    public String insert(Map<String, Object> map, ProviderContext context) throws Exception {
        Object bean = map.get("bean");
        if (bean == null) {
            throw new IllegalArgumentException("insert: bean must not be null");
        }
        List<Object> list = new ArrayList<>();
        list.add(bean);
        // The batch SQL references #{list[i].field}, so the list has to live in the parameter map.
        map.put("list", list);
        return batchInsert(map, context);
    }

    /**
     * Builds a multi-row INSERT. When {@code updateFields} is given, appends a MySQL
     * {@code ON DUPLICATE KEY UPDATE} clause for those columns.
     *
     * @param map     parameter map holding {@code list} (beans to insert) and {@code updateFields}
     *                (column or field names to update on a unique-key conflict; may be null)
     * @param context provider context injected by MyBatis
     * @return the INSERT statement
     * @throws Exception if the list is null/empty, contains null, or no column has a value
     */
    @SuppressWarnings("unchecked")
    public String batchInsert(Map<String, Object> map, ProviderContext context) throws Exception {
        List<Object> list = (List<Object>) map.get("list");
        if (CollectionUtils.isEmpty(list)) {
            // A null SQL string would only fail later inside MyBatis with an unclear error.
            throw new IllegalArgumentException("batchInsert: list must not be null or empty");
        }
        for (Object bean : list) {
            if (bean == null) {
                throw new IllegalArgumentException("batchInsert: list must not contain null elements");
            }
        }
        Set<String> updateFields = (Set<String>) map.get("updateFields");
        Class<?> beanClass = getBeanClass(context);
        SQL sql = new SQL();
        sql.INSERT_INTO(getTableName(beanClass));
        List<String> updateAssignments = fillInsert(beanClass, list, sql, updateFields);
        String statement = sql.toString();
        if (!updateAssignments.isEmpty()) {
            statement += " ON DUPLICATE KEY UPDATE " + String.join(",", updateAssignments);
        }
        return statement;
    }

    /**
     * Builds an UPDATE that sets every non-null, non-key field and matches on all @Id fields.
     *
     * @param map     parameter map holding {@code bean}
     * @param context provider context injected by MyBatis
     * @return the UPDATE statement
     * @throws Exception if the bean is null, the bean class has no @Id field (the statement
     *                   would otherwise update every row), or no field has a value to set
     */
    public String updateByPrimaryKey(Map<String, Object> map, ProviderContext context) throws Exception {
        Object bean = map.get("bean");
        if (bean == null) {
            throw new IllegalArgumentException("updateByPrimaryKey: bean must not be null");
        }
        Class<?> beanClass = getBeanClass(context);
        SQL sql = new SQL();
        sql.UPDATE(getTableName(beanClass));
        int keyCount = 0;
        int setCount = 0;
        for (Field field : getColumnFields(beanClass)) {
            String assignment = getColumnName(field) + "=#{bean." + field.getName() + "}";
            if (field.isAnnotationPresent(Id.class)) {
                sql.WHERE(assignment);
                keyCount++;
            } else if (field.get(bean) != null) {
                sql.SET(assignment);
                setCount++;
            }
        }
        if (keyCount == 0) {
            throw new IllegalStateException("updateByPrimaryKey: no @Id field declared on " + beanClass.getName());
        }
        if (setCount == 0) {
            throw new IllegalArgumentException("updateByPrimaryKey: every non-@Id field is null, nothing to update");
        }
        return sql.toString();
    }

    /**
     * Builds a DELETE matching the @Id column against {@code #{primaryKey}}.
     *
     * @param context provider context injected by MyBatis
     * @return the DELETE statement
     * @throws Exception if the bean class has no @Id field
     */
    public String deleteByPrimaryKey(ProviderContext context) throws Exception {
        Class<?> beanClass = getBeanClass(context);
        SQL sql = new SQL();
        sql.DELETE_FROM(getTableName(beanClass));
        sql.WHERE(getColumnName(getIdField(beanClass)) + "=#{primaryKey}");
        return sql.toString();
    }

    /**
     * Builds a DELETE whose conditions are the bean's non-null fields ({@code column = value}).
     *
     * @param map     parameter map holding {@code bean}
     * @param context provider context injected by MyBatis
     * @return the DELETE statement
     * @throws Exception if the bean is null or has no non-null field, because an unconditional
     *                   DELETE would wipe the whole table
     */
    public String deleteByBean(Map<String, Object> map, ProviderContext context) throws Exception {
        Class<?> beanClass = getBeanClass(context);
        SQL sql = new SQL();
        sql.DELETE_FROM(getTableName(beanClass));
        if (appendBeanConditions(map.get("bean"), beanClass, sql) == 0) {
            throw new IllegalArgumentException(
                    "deleteByBean: bean is null or has no non-null field, refusing to delete every row");
        }
        return sql.toString();
    }

    /**
     * Builds a SELECT by primary key.
     *
     * @param map     parameter map holding {@code primaryKey} and {@code fields}
     * @param context provider context injected by MyBatis
     * @return the SELECT statement
     * @throws Exception if the bean class has no @Id field or no requested field is valid
     */
    public String queryObjectByPrimaryKey(Map<String, Object> map, ProviderContext context) throws Exception {
        Class<?> beanClass = getBeanClass(context);
        SQL sql = newSelect(beanClass, map);
        sql.WHERE(getColumnName(getIdField(beanClass)) + "=#{primaryKey}");
        return sql.toString();
    }

    /**
     * Builds a list SELECT filtered by the bean's non-null fields, optional IN conditions,
     * and optional sorting/paging.
     *
     * @param map     parameter map holding {@code bean}, {@code inMap}, {@code pageable}, {@code fields}
     * @param context provider context injected by MyBatis
     * @return the SELECT statement
     * @throws Exception if a field, key or sort property is invalid
     */
    public String queryListByBean(Map<String, Object> map, ProviderContext context) throws Exception {
        Class<?> beanClass = getBeanClass(context);
        SQL sql = newSelect(beanClass, map);
        appendBeanConditions(map.get("bean"), beanClass, sql);
        fillInConditions(map, sql, beanClass);
        fillPage(map, sql, beanClass);
        return sql.toString();
    }

    /**
     * Builds a list SELECT filtered by operator conditions, optional IN conditions,
     * and optional sorting/paging.
     *
     * @param map     parameter map holding {@code conditions}, {@code inMap}, {@code pageable}, {@code fields}
     * @param context provider context injected by MyBatis
     * @return the SELECT statement
     * @throws Exception if a field, key, operator or sort property is invalid
     */
    public String queryList(Map<String, Object> map, ProviderContext context) throws Exception {
        Class<?> beanClass = getBeanClass(context);
        SQL sql = newSelect(beanClass, map);
        fillConditions(map, sql, beanClass);
        fillInConditions(map, sql, beanClass);
        fillPage(map, sql, beanClass);
        return sql.toString();
    }

    /** Starts a SELECT with the requested columns (from {@code fields}) FROM the bean's table. */
    private static SQL newSelect(Class<?> beanClass, Map<String, Object> map) {
        SQL sql = new SQL();
        String[] fields = (String[]) map.get("fields");
        fillSelectFields(beanClass, sql, fields == null ? null : Stream.of(fields).collect(Collectors.toSet()));
        sql.FROM(getTableName(beanClass));
        return sql;
    }

    /** Appends ORDER BY and LIMIT/OFFSET from {@code pageable}, if present. */
    private static void fillPage(Map<String, Object> map, SQL sql, Class<?> beanClass) {
        Pageable pageable = (Pageable) map.get("pageable");
        if (pageable == null) {
            return;
        }
        for (Sort.Order order : pageable.getSort()) {
            // Sort properties are spliced into the SQL text, so they must resolve to a safe column.
            sql.ORDER_BY(resolveColumn(beanClass, order.getProperty()) + " " + order.getDirection());
        }
        // Pageable.unpaged() throws from getOffset(); only paged requests get LIMIT/OFFSET.
        if (pageable.isPaged()) {
            sql.OFFSET(pageable.getOffset()).LIMIT(pageable.getPageSize());
        }
    }

    /**
     * Appends one {@code column <operator> #{conditions.<key>.value}} per entry of {@code conditions}.
     * Each entry value is a map with {@code operate} (one of {@link #ALLOWED_OPERATORS}) and {@code value}.
     */
    @SuppressWarnings("unchecked")
    private static void fillConditions(Map<String, Object> map, SQL sql, Class<?> beanClass) {
        Map<String, Map<String, Object>> conditions = (Map<String, Map<String, Object>>) map.get("conditions");
        if (conditions == null) {
            return;
        }
        for (Map.Entry<String, Map<String, Object>> entry : conditions.entrySet()) {
            String key = entry.getKey();
            String column = resolveColumn(beanClass, key);
            Map<String, Object> condition = entry.getValue();
            Object operate = condition == null ? null : condition.get("operate");
            String operator = operate == null ? null
                    : operate.toString().trim().replaceAll("\\s+", " ").toUpperCase(Locale.ROOT);
            if (operator == null || !ALLOWED_OPERATORS.contains(operator)) {
                throw new IllegalArgumentException("unsupported operator '" + operate + "' for condition '"
                        + key + "', allowed: " + ALLOWED_OPERATORS);
            }
            sql.WHERE(column + " " + operator + " #{conditions." + key + ".value}");
        }
    }

    /**
     * Appends {@code column in (#{inMap.key[0]},...)} per entry of {@code inMap}.
     * An empty or null value list becomes {@code 1 = 0} (matches nothing) instead of invalid SQL.
     */
    @SuppressWarnings("unchecked")
    private static void fillInConditions(Map<String, Object> map, SQL sql, Class<?> beanClass) {
        Map<String, List<String>> inMap = (Map<String, List<String>>) map.get("inMap");
        if (inMap == null) {
            return;
        }
        for (Map.Entry<String, List<String>> entry : inMap.entrySet()) {
            String key = entry.getKey();
            String column = resolveColumn(beanClass, key);
            List<String> values = entry.getValue();
            if (CollectionUtils.isEmpty(values)) {
                sql.WHERE("1 = 0");
                continue;
            }
            StringBuilder placeholders = new StringBuilder();
            for (int i = 0; i < values.size(); i++) {
                if (i > 0) {
                    placeholders.append(',');
                }
                placeholders.append("#{inMap.").append(key).append('[').append(i).append("]}");
            }
            sql.WHERE(column + " in (" + placeholders + ")");
        }
    }

    /**
     * Maps a caller-supplied key to the column spliced into SQL. A key equal to a mapped column
     * name is returned as-is; a key equal to a field name resolves to that field's column; any other
     * key must be a bare identifier. This blocks SQL injection through map keys and sort properties.
     */
    private static String resolveColumn(Class<?> beanClass, String key) {
        if (key == null) {
            throw new IllegalArgumentException("column name must not be null");
        }
        String byFieldName = null;
        for (Field field : getColumnFields(beanClass)) {
            String column = getColumnName(field);
            if (column.equals(key)) {
                return column;
            }
            if (byFieldName == null && field.getName().equals(key)) {
                byFieldName = column;
            }
        }
        if (byFieldName != null) {
            return byFieldName;
        }
        if (!IDENTIFIER_PATTERN.matcher(key).matches()) {
            throw new IllegalArgumentException("invalid column name: " + key);
        }
        return key;
    }

    /** Returns the first field annotated with @Id; throws when the bean class declares none. */
    private static Field getIdField(Class<?> beanClass) {
        for (Field field : getColumnFields(beanClass)) {
            if (field.isAnnotationPresent(Id.class)) {
                return field;
            }
        }
        throw new IllegalStateException("no @Id field declared on " + beanClass.getName());
    }

    /**
     * Appends {@code column=#{bean.field}} for every non-null field of {@code bean} (including @Id).
     * Primitive fields are never null and therefore always become conditions.
     *
     * @return the number of conditions added (0 when {@code bean} is null)
     */
    private static int appendBeanConditions(Object bean, Class<?> beanClass, SQL sql) throws IllegalAccessException {
        if (bean == null) {
            return 0;
        }
        int count = 0;
        for (Field field : getColumnFields(beanClass)) {
            if (field.get(bean) != null) {
                sql.WHERE(getColumnName(field) + "=#{bean." + field.getName() + "}");
                count++;
            }
        }
        return count;
    }

    /**
     * Adds INSERT columns and one VALUES row per bean. A non-@Id field becomes a column when it is
     * non-null in at least one bean, so values in later rows are not dropped. @Id fields are never
     * inserted.
     *
     * @param updateFields column or field names to refresh on a duplicate key; may be null
     * @return the {@code col=VALUES(col)} assignments for ON DUPLICATE KEY UPDATE (possibly empty)
     */
    private static List<String> fillInsert(Class<?> beanClass, List<Object> list, SQL sql, Set<String> updateFields)
            throws IllegalAccessException {
        List<Field> columns = new ArrayList<>();
        List<String> updateAssignments = new ArrayList<>();
        for (Field field : getColumnFields(beanClass)) {
            if (field.isAnnotationPresent(Id.class) || !anyNonNull(field, list)) {
                continue;
            }
            String columnName = getColumnName(field);
            sql.INTO_COLUMNS(columnName);
            columns.add(field);
            if (updateFields != null
                    && (updateFields.contains(columnName) || updateFields.contains(field.getName()))) {
                updateAssignments.add(columnName + "=VALUES(" + columnName + ")");
            }
        }
        if (columns.isEmpty()) {
            throw new IllegalArgumentException("insert: every non-@Id field is null, nothing to insert");
        }
        for (int row = 0; row < list.size(); row++) {
            if (row > 0) {
                sql.ADD_ROW();
            }
            StringBuilder values = new StringBuilder();
            for (Field field : columns) {
                if (values.length() > 0) {
                    values.append(',');
                }
                values.append("#{list[").append(row).append("].").append(field.getName()).append('}');
            }
            sql.INTO_VALUES(values.toString());
        }
        return updateAssignments;
    }

    /** Returns true when {@code field} is non-null on at least one of {@code rows}. */
    private static boolean anyNonNull(Field field, List<?> rows) throws IllegalAccessException {
        for (Object row : rows) {
            if (field.get(row) != null) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the column name of a field: a non-empty {@code @Column(name)} wins; otherwise the
     * field name converted to snake_case (an empty name, the JPA default, also falls back).
     */
    private static String getColumnName(Field field) {
        Column column = field.getDeclaredAnnotation(Column.class);
        if (column != null && !column.name().isEmpty()) {
            return column.name();
        }
        return humpToLine2(field.getName());
    }

    /**
     * Adds {@code column AS fieldName} for each requested field, or for every mapped field when
     * {@code selectFields} is null/empty. Requested names may be column names or field names.
     *
     * @throws IllegalArgumentException when nothing would be selected
     */
    private static void fillSelectFields(Class<?> beanClass, SQL sql, Set<String> selectFields) {
        boolean selectAll = CollectionUtils.isEmpty(selectFields);
        int selected = 0;
        for (Field field : getColumnFields(beanClass)) {
            String columnName = getColumnName(field);
            if (selectAll || selectFields.contains(columnName) || selectFields.contains(field.getName())) {
                sql.SELECT(columnName + " AS " + field.getName());
                selected++;
            }
        }
        if (selected == 0) {
            throw new IllegalArgumentException("无有效字段: " + selectFields);
        }
    }

    /** Returns the accessible, mapped (non-static, non-synthetic) fields declared by {@code beanClass}. */
    private static List<Field> getColumnFields(Class<?> beanClass) {
        List<Field> result = new ArrayList<>();
        for (Field field : beanClass.getDeclaredFields()) {
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            field.setAccessible(true);
            result.add(field);
        }
        return result;
    }

    /** Resolves the bean type {@code T} from the mapper interface that extends the generic DAO. */
    private static Class<?> getBeanClass(ProviderContext context) {
        Class<?> mapperType = context.getMapperType();
        Class<?> beanClass = findBeanType(mapperType);
        if (beanClass == null) {
            throw new IllegalStateException("cannot resolve the bean type of mapper " + mapperType.getName()
                    + "; it must extend MybatisBaseDao<Bean, Key> with concrete type arguments");
        }
        return beanClass;
    }

    /**
     * Depth-first search of the interface hierarchy for the first parameterized interface whose
     * first type argument is a concrete class.
     */
    private static Class<?> findBeanType(Type type) {
        if (type instanceof ParameterizedType) {
            ParameterizedType parameterized = (ParameterizedType) type;
            Type firstArgument = parameterized.getActualTypeArguments()[0];
            if (firstArgument instanceof Class) {
                return (Class<?>) firstArgument;
            }
            return findBeanType(parameterized.getRawType());
        }
        if (type instanceof Class) {
            for (Type parent : ((Class<?>) type).getGenericInterfaces()) {
                Class<?> found = findBeanType(parent);
                if (found != null) {
                    return found;
                }
            }
        }
        return null;
    }

    /**
     * Returns the table name: a non-empty {@code @Table(name)} wins; otherwise the simple class
     * name converted to snake_case.
     */
    private static String getTableName(Class<?> beanClass) {
        Table table = beanClass.getDeclaredAnnotation(Table.class);
        if (table != null && !table.name().isEmpty()) {
            return table.name();
        }
        return humpToLine2(beanClass.getSimpleName());
    }

    /**
     * Converts camelCase/PascalCase to lower-case snake_case ({@code createTime -> create_time},
     * {@code TestBean -> test_bean}). No underscore is emitted before a leading capital, and
     * lower-casing uses {@link Locale#ROOT} so the result does not depend on the default locale.
     */
    private static String humpToLine2(String str) {
        Matcher matcher = HUMP_PATTERN.matcher(str);
        StringBuffer sb = new StringBuffer();
        while (matcher.find()) {
            String lower = matcher.group().toLowerCase(Locale.ROOT);
            matcher.appendReplacement(sb, matcher.start() == 0 ? lower : "_" + lower);
        }
        matcher.appendTail(sb);
        return sb.toString();
    }
}
测试dao
import com.johny.test.bean.TestBean;
import com.johny.test.dao.MybatisBaseDao;
import com.johny.test.util.MybatisUtil;
import org.apache.ibatis.annotations.*;
import java.util.List;
import java.util.Map;
/**
 * Mapper for TestBean with a Long primary key. It declares no methods of its own: every CRUD
 * method is inherited from MybatisBaseDao, whose provider derives table and column names from
 * TestBean.
 *
 * @author johny
 * @date 2022/8/30
 */
public interface TestDao extends MybatisBaseDao<TestBean,Long> {
}
测试service
下面是使用Pageable的示例代码,可以参考一下
/**
* @author johny
* @date 2022/9/3
*/
@Service
@Lazy
public class UpdateUnionIdService {
@Autowired
TestDao testDao;
/**
*测试查询
*/
public void query() throws Exception {
//查询条件
TestBean testBean = new TestBean();
testBean.setId(16021L);
//排序条件
Sort.Order orderId = new Sort.Order(Sort.Direction.ASC, "id");
Sort.Order orderOpenid = new Sort.Order(Sort.Direction.DESC, "name");
Sort sort = Sort.by(orderId, orderOpenid);
//分页
Pageable page=PageRequest.of(0, 10, sort)
//in条件
Map<String, List<String>> inMap = new HashMap<>();
List<String> strings = new ArrayList<>();
strings.add("测试");
strings.add("4");
inMap.put("name", strings);
//查询 create_time为查询展示字段
List<TestBean> beans = testDao.queryListByBean(null, inMap, page, "create_time");
if (!CollectionUtils.isEmpty(beans)) {
System.out.println(JSON.toJSONString(beans));
}
Map<String, Map<String,Object>> conditions=new HashMap<>();
Map<String,Object> condition=new HashMap<>();
condition.put("operate",">");
condition.put("value",0);
conditions.put("id",condition);
Map<String,Object> condition2=new HashMap<>();
condition2.put("operate","<");
condition2.put("value",10);
conditions.put("id",condition2);
List<TestBean> beans = testDao.queryList(conditions, null, page, "create_time");
}
}
最后
建议将 MyBatis 的 map-underscore-to-camel-case 配置设置为 true。
以上就是我自己抽取的公用方法,没有实现更复杂的查询。考虑过之后,觉得如果再引入条件构造工具类和更新工具类,还不如在需要复杂查询时直接用 @Select 写 SQL 来得方便,所以没有继续扩展。
2024年1月27日更新
1.增加字段错误抛异常
2.增加支持普通where条件方法
希望大家可以参照示例,扩展出更多适合自己的方法。