batch-insert-JdbcTemplate

本文介绍了如何使用Java的JdbcTemplate和自定义注解实现数据库的批量写入,包括更新和插入操作。通过注解处理表和字段信息,动态生成SQL语句,以提高效率和灵活性。示例代码展示了如何在SpringBoot项目中应用这些工具。
摘要由CSDN通过智能技术生成
前言

后台开发中,批量往数据库写数据是一个很常见的功能,下面就简单实现一下使用 JdbcTemplate 来 batch 写入。

实现介绍
添加依赖

在项目的 pom.xml 中配置 JdbcTemplate 及 mysql 相关的依赖

<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>5.1.46</version>
</dependency>
<dependency>
    <groupId>org.mybatis.spring.boot</groupId>
    <artifactId>mybatis-spring-boot-starter</artifactId>
    <version>2.2.0</version>
</dependency>
自定义 table 注解

新建一个自定义注解,作用于类上面,用于备注 table 信息

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.List;

/**
 * Description:需要进行入库操作的表实体注解
 *
 * @author fy
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface DbTable {

    /***
     * <p>
     * 表名
     * </p>
     * @author fy
     *
     * @return java.lang.String
     */
    String tableName();

    /***
     * <p>
     * 表的主键
     * </p>
     * @author fy
     *
     * @return java.lang.String
     */
    String primaryKey() default "";

    /***
     * <p>
     * 不参与操作的黑名单字段属性
     * </p>
     * @author fy
     *
     * @return java.lang.String[]
     */
    String[] blackIgnoreFieldList() default {};
}

自定义 column 注解

新建一个自定义注解,作用于类属性上面,用于备注 column 信息

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Description:需要进行入库操作的表字段注解
 *
 * @author fy
 * @version 1.0
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface DbColumn {

    /***
     * <p>
     * 该字段是否允许为空(NULL)
     * </p>
     * @author fy
     *
     * @return boolean
     */
    boolean nullable() default false;

    /***
     * <p>
     * 新增时忽略
     * </p>
     * @author fy
     *
     * @return boolean
     */
    boolean ignoredInsert() default false;

    /***
     * <p>
     * 更新时忽略
     * </p>
     * @author fy
     *
     * @return boolean
     */
    boolean ignoredUpdate() default false;

    /***
     * <p>
     * 字段为 NULL 时候的默认值,不能漏了''
     * </p>
     * @author fy
     *
     * @return java.lang.String
     */
    String defaultValue() default "";
}
逻辑处理工具类

利用自定义的两个注解搭配反射获取表及表字段相关信息,然后通过字符串拼装出 batch 操作的 sql,最后通过 jdbcTemplate.execute 来执行该 sql 进行批量操作。

该类需要加入 spring 管理。

import com.fy.util.NormalUtil;
import com.fy.util.db.ann.DbColumn;
import com.fy.util.db.ann.DbTable;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Description:db 操作工具类
 *
 * @author fy
 * @version 1.0
 */
@Component
public class DbUtil {

    @Value("${db-util.sql.log:false}")
    private boolean isSqlLog;

    @Autowired
    private JdbcTemplate jdbcTemplate;

    /***
     * <p>
     * 批量执行的批次默认大小
     * </p>
     */
    private static final int BATCH_SIZE_DEFAULT = 5000;

    private static final String EMPTY_COLUMN_DATA = "";

    /***
     * <p>
     * 批量更新(通过主键)
     * </p>
     * @author fy
     *
     * @param list 批量更新的数据
     */
    public void batchUpdateByPrimaryKey(List<?> list) {
        batchUpdateByPrimaryKey(list, BATCH_SIZE_DEFAULT);
    }

    /***
     * <p>
     * 批量更新(通过主键)
     * </p>
     * @author fy
     *
     * @param list 批量更新的数据
     * @param batchSize 批量批次的大小
     */
    public void batchUpdateByPrimaryKey(List<?> list, int batchSize) {
        ClassBean classBean = checkAndLoadClass(list);
        if (classBean == null) {
            return;
        }
        DbTable dbTable = classBean.getDbTable();
        String primaryKey = dbTable.primaryKey();
        if ("".equals(primaryKey.trim())) {
            // 配置没有指定主键
            return;
        }
        batchUpdate(list, Collections.singletonList(primaryKey), batchSize);
    }

    /***
     * <p>
     * 批量更新
     * </p>
     * @author fy
     *
     * @param list 批量更新的数据
     * @param whereColList 作为 where 条件的字段
     */
    public void batchUpdate(List<?> list, List<String> whereColList) {
        batchUpdate(list, whereColList, BATCH_SIZE_DEFAULT);
    }

    /***
     * <p>
     * 批量更新
     * </p>
     * @author fy
     *
     * @param list 批量更新的数据
     * @param whereColList 作为 where 条件的字段
     * @param batchSize 批量批次的大小
     */
    public void batchUpdate(List<?> list, List<String> whereColList, int batchSize) {
        ClassBean classBean = checkAndLoadClass(list);
        if (classBean == null) {
            return;
        }
        DbTable dbTable = classBean.getDbTable();
        String tableName = dbTable.tableName();
        Map<String, FieldBean> map = classBean.getFieldBeanMap();
        if (map == null || map.size() == 0) {
            return;
        }
        StringBuilder sb = new StringBuilder();
        int count = 0;
        for (Object o : list) {
            StringBuilder toUpdateSb = new StringBuilder();
            try {
                boolean hasToUpdate = isHasToUpdate(whereColList, map, o, toUpdateSb);
                if (hasToUpdate) {
                    sb.append(" UPDATE ").append(tableName).append(" SET ")
                            .append(toUpdateSb.toString());
                } else {
                    // 说明没有需要更新的字段
                    continue;
                }
                loadWhere(whereColList, map, sb, o);
                sb.append("; ");
            } catch (IllegalAccessException e) {
                // 获取值失败
                continue;
            }
            count ++;
            if (count == batchSize) {
                // 达到了批次数量,开始执行,并清理 sql
                executeSql(sb.toString());
                count = 0;
                sb.delete(0, sb.length());
            }
        }
        if (count > 0) {
            // 执行未执行的批次
            executeSql(sb.toString());
            sb.delete(0, sb.length());
        }
    }

    private boolean isHasToUpdate(List<String> whereColList,
                                  Map<String, FieldBean> map,
                                  Object data,
                                  StringBuilder toUpdateSb)
            throws IllegalAccessException {
        boolean hasToUpdate = false;
        for (Map.Entry<String, FieldBean> entry : map.entrySet()) {
            String name = entry.getKey();
            if (whereColList != null && whereColList.contains(name)) {
                // 这个字段在 where 条件中
                continue;
            }
            FieldBean fieldBean = entry.getValue();
            DbColumn dbColumn = fieldBean.getDbColumn();
            if (dbColumn != null && dbColumn.ignoredUpdate()) {
                // 该字段设置了忽略更新
                continue;
            }
            Object colData = fieldBean.getField().get(data);
            if (colData == null) {
                // 该字段值为空
                if (dbColumn != null && dbColumn.nullable()) {
                    // 设置了该字段允许为空
                    if (hasToUpdate) {
                        toUpdateSb.append(", ");
                    }
                    toUpdateSb.append(name).append(" = NULL ");
                    hasToUpdate = true;
                }
            }
            else {
                // 该字段值不为空
                if (hasToUpdate) {
                    toUpdateSb.append(", ");
                }
                toUpdateSb.append(name).append(" = ")
                        .append("'").append(getData(colData)).append("'");
                hasToUpdate = true;
            }
        }
        return hasToUpdate;
    }

    private void loadWhere(List<String> whereColList,
                           Map<String, FieldBean> map,
                           StringBuilder sqlSb,
                           Object data)
            throws IllegalAccessException {
        if (whereColList != null && whereColList.size() > 0) {
            StringBuilder whereSb = new StringBuilder();
            boolean hasWhere = false;
            for (String colName : whereColList) {
                FieldBean fieldBean = map.get(colName);
                if (fieldBean == null) {
                    // 不存在这个字段
                    continue;
                }
                if (hasWhere) {
                    whereSb.append(" AND ");
                }
                Object colData = fieldBean.getField().get(data);
                whereSb.append(colName).append(" = ")
                        .append("'").append(getData(colData)).append("'");
                hasWhere = true;
            }
            if (hasWhere) {
                sqlSb.append(" WHERE ").append(whereSb.toString());
            }
        }
    }

    /***
     * <p>
     * 批量新增(非空字段)
     * </p>
     * @author fy
     *
     * @param list 批量新增的数据
     */
    public void batchInsertSelective(List<?> list) {
        batchInsertSelective(list, BATCH_SIZE_DEFAULT);
    }

    /***
     * <p>
     * 批量新增(非空字段)
     * </p>
     * @author fy
     *
     * @param list 批量新增的数据
     * @param batchSize 批量的批次大小
     */
    public void batchInsertSelective(List<?> list, int batchSize) {
        ClassBean classBean = checkAndLoadClass(list);
        if (classBean == null) {
            return;
        }
        DbTable dbTable = classBean.getDbTable();
        String tableName = dbTable.tableName();
        Map<String, FieldBean> map = classBean.getFieldBeanMap();
        if (map == null || map.size() == 0) {
            return;
        }
        StringBuilder sb = new StringBuilder();
        int count = 0;
        for (Object o : list) {
            try {
                StringBuilder columnSb = new StringBuilder();
                StringBuilder dataSb = new StringBuilder();
                boolean hasToInsert = isHasToInsert(map, o, columnSb, dataSb);
                if (hasToInsert) {
                    sb.append(" INSERT INTO ").append(tableName).append(" (")
                            .append(columnSb.toString()).append(") VALUES (")
                            .append(dataSb.toString()).append("); ");
                } else {
                    // 没有需要插入的数据
                    continue;
                }
            } catch (IllegalAccessException ignored) {
                // 获取值发生异常
                continue;
            }
            count ++;
            if (count == batchSize) {
                // 达到了批次数量,开始执行,并清理 sql
                executeSql(sb.toString());
                count = 0;
                sb.delete(0, sb.length());
            }
        }
        if (count > 0) {
            // 执行未执行的批次
            executeSql(sb.toString());
            sb.delete(0, sb.length());
        }
    }

    private boolean isHasToInsert(Map<String, FieldBean> map,
                                  Object data,
                                  StringBuilder columnSb,
                                  StringBuilder dataSb)
            throws IllegalAccessException {
        boolean hasToInsert = false;
        for (Map.Entry<String, FieldBean> entry : map.entrySet()) {
            String columnName = entry.getKey();
            FieldBean fieldBean = entry.getValue();
            Object colData = fieldBean.getField().get(data);
            DbColumn dbColumn = fieldBean.getDbColumn();
            if (dbColumn != null && dbColumn.ignoredInsert()) {
                continue;
            }
            if (colData == null) {
                // 该字段值为空
                if (dbColumn != null && !EMPTY_COLUMN_DATA.equals(dbColumn.defaultValue())) {
                    // 该字段配置了默认值
                    if (hasToInsert) {
                        columnSb.append(",");
                        dataSb.append(",");
                    }
                    columnSb.append(columnName);
                    dataSb.append(dbColumn.defaultValue());
                    hasToInsert = true;
                } else {
                    if (dbColumn != null && dbColumn.nullable()) {
                        // 设置了该字段允许为空
                        if (hasToInsert) {
                            columnSb.append(",");
                            dataSb.append(",");
                        }
                        columnSb.append(columnName);
                        dataSb.append(" NULL ");
                        hasToInsert = true;
                    }
                }
            } else {
                // 字段值不为空
                if (hasToInsert) {
                    columnSb.append(",");
                    dataSb.append(",");
                }
                columnSb.append(columnName);
                dataSb.append("'").append(getData(colData)).append("'");
                hasToInsert = true;
            }
        }
        return hasToInsert;
    }

    /***
     * <p>
     * 批量新增(所有字段)
     * </p>
     * @author fy
     *
     * @param list 批量新增的数据
     */
    public void batchInsert(List<?> list) {
        batchInsert(list, BATCH_SIZE_DEFAULT);
    }

    /***
     * <p>
     * 批量新增(所有字段)
     * </p>
     * @author fy
     *
     * @param list 批量新增的数据
     * @param batchSize 批量的批次大小
     */
    public void batchInsert(List<?> list, int batchSize) {
        ClassBean classBean = checkAndLoadClass(list);
        if (classBean == null) {
            return;
        }
        DbTable dbTable = classBean.getDbTable();
        String tableName = dbTable.tableName();
        Map<String, FieldBean> map = classBean.getFieldBeanMap();
        if (map == null || map.size() == 0) {
            return;
        }
        StringBuilder sb = new StringBuilder();
        StringBuilder columnSb = new StringBuilder();
        // 先组装出字段字符串
        boolean isNotFirst = false;
        for (Map.Entry<String, FieldBean> entry : map.entrySet()) {
            FieldBean fieldBean = entry.getValue();
            DbColumn dbColumn = fieldBean.getDbColumn();
            if (dbColumn != null && dbColumn.ignoredInsert()) {
                continue;
            }
            if (isNotFirst) {
                columnSb.append(",");
            }
            columnSb.append(entry.getKey());
            isNotFirst = true;
        }
        int count = 0;
        StringBuilder dataSb = new StringBuilder();
        for (Object o : list) {
            if (sb.length() == 0) {
                sb.append(" INSERT INTO ").append(tableName).append(" (")
                        .append(columnSb.toString()).append(") VALUES ");
            }
            if (count > 0) {
                dataSb.append(",");
            }
            try {
                dataSb.append("(");
                boolean isNotFirstColumn = false;
                for (Map.Entry<String, FieldBean> entry : map.entrySet()) {
                    FieldBean fieldBean = entry.getValue();
                    DbColumn dbColumn = fieldBean.getDbColumn();
                    if (dbColumn != null && dbColumn.ignoredInsert()) {
                        continue;
                    }
                    Object data = fieldBean.getField().get(o);
                    if (isNotFirstColumn) {
                        dataSb.append(",");
                    }
                    if (data == null) {
                        // 该字段值为空
                        if (dbColumn != null && !EMPTY_COLUMN_DATA.equals(dbColumn.defaultValue())) {
                            // 该字段配置了默认值
                            dataSb.append(dbColumn.defaultValue());
                        } else {
                            dataSb.append(" NULL ");
                        }
                    } else {
                        // 字段值不为空
                        dataSb.append("'").append(getData(data)).append("'");
                    }
                    isNotFirstColumn = true;
                }
                dataSb.append(")");
            } catch (IllegalAccessException ignored) {
                // 获取值发生异常
                continue;
            }
            count ++;
            if (count == batchSize) {
                // 达到了批次数量,开始执行,并清理 sql
                sb.append(dataSb.toString()).append(";");
                executeSql(sb.toString());
                count = 0;
                dataSb.delete(0, dataSb.length());
                sb.delete(0, sb.length());
            }
        }
        if (count > 0) {
            // 执行未执行的批次
            sb.append(dataSb.toString()).append(";");
            executeSql(sb.toString());
            dataSb.delete(0, dataSb.length());
            sb.delete(0, sb.length());
        }
    }

    /***
     * <p>
     * 执行 sql
     * </p>
     * @author fy
     *
     * @param sql 待执行的 sql 语句
     */
    private void executeSql(String sql) {
        if (isSqlLog) {
            System.err.println(sql);
        }
        jdbcTemplate.execute(sql);
    }

    private Object getData(Object data) {
        if (data == null) {
            return null;
        }
        if (data instanceof Date) {
            data = NormalUtil.formatDate((Date) data, "yyyy-MM-dd HH:mm:ss");
        }
        return data;
    }

    /***
     * <p>
     * 基础检查类信息
     * </p>
     * @author fy
     *
     * @param list 数据列表
     * @return com.fy.util.db.DbUtil.ClassBean
     */
    private ClassBean checkAndLoadClass(List<?> list) {
        if (list == null || list.size() == 0) {
            // 数据为空
            return null;
        }
        Class<?> clazz = list.get(0).getClass();
        DbTable dbTable = clazz.getAnnotation(DbTable.class);
        if (dbTable == null) {
            // 必须有这个注解
            throw new RuntimeException(clazz.getName() + " 类缺失【DbTable】注解!");
        }
        return getTransClass(clazz);
    }

    private static final Map<String, ClassBean> CLASS_TRANS_MAP_CACHE = new ConcurrentHashMap<>();
    private static Set<String> BLACK_FIELD_IGNORED_SET =
            Collections.singleton("serialVersionUID");

    /***
     * <p>
     * 通过 class 类获取该类中相关信息
     * </p>
     * @author fy
     *
     * @param clazz 操作的 class 类
     * @return com.fy.util.db.DbUtil.ClassBean
     */
    private ClassBean getTransClass(Class<?> clazz) {
        String clazzName = clazz.getName();
        if (CLASS_TRANS_MAP_CACHE.containsKey(clazzName)) {
            return CLASS_TRANS_MAP_CACHE.get(clazzName);
        }
        ClassBean classBean = CLASS_TRANS_MAP_CACHE.computeIfAbsent(clazzName,
                k -> new ClassBean());
        Field[] fields = clazz.getDeclaredFields();
        Map<String, FieldBean> fieldBeanMap = new HashMap<>(fields.length);
        DbTable dbTable = clazz.getAnnotation(DbTable.class);
        if (dbTable == null) {
            throw new RuntimeException(clazz.getName() + " 类缺失【DbTable】注解!");
        }
        String[] blackIgnoreFieldList = dbTable.blackIgnoreFieldList();
        BLACK_FIELD_IGNORED_SET.addAll(Arrays.asList(blackIgnoreFieldList));
        for (Field field : fields) {
            if (BLACK_FIELD_IGNORED_SET.contains(field.getName())) {
                continue;
            }
            field.setAccessible(true);
            // key 是下划线格式
            fieldBeanMap.put(NormalUtil.transUpper2Under(field.getName()),
                    new FieldBean(field, field.getAnnotation(DbColumn.class)));
        }
        classBean.setClazz(clazz);
        classBean.setDbTable(dbTable);
        classBean.setFieldBeanMap(fieldBeanMap);
        return classBean;
    }

    /***
     * <p>
     * 储存 class 相关信息的 bean
     * </p>
     * @author fy
     *
     */
    @Getter
    @Setter
    private static class ClassBean {
        Class<?> clazz;
        DbTable dbTable;
        Map<String, FieldBean> fieldBeanMap;
    }

    /***
     * <p>
     * 储存字段相关信息的 bean
     * </p>
     * @author fy
     *
     */
    @Getter
    @Setter
    @NoArgsConstructor
    private static class FieldBean {
        Field field;
        DbColumn dbColumn;

        FieldBean(Field field, DbColumn dbColumn) {
            this.field = field;
            this.dbColumn = dbColumn;
        }
    }
}

方法概览

预览工具类中的几个 public 方法

// 批量更新(通过主键)
public void batchUpdateByPrimaryKey(List<?> list);
// 批量更新(通过主键)
public void batchUpdateByPrimaryKey(List<?> list, int batchSize);
// 批量更新
public void batchUpdate(List<?> list, List<String> whereColList);
// 批量更新
public void batchUpdate(List<?> list, List<String> whereColList, int batchSize);
// 批量新增(非空字段)
public void batchInsertSelective(List<?> list);
// 批量新增(非空字段)
public void batchInsertSelective(List<?> list, int batchSize);
// 批量新增(所有字段)
public void batchInsert(List<?> list);
// 批量新增(所有字段)
public void batchInsert(List<?> list, int batchSize);
使用

1、先在需批量操作的 bean 类上添加相关注解

import com.fy.util.db.ann.DbColumn;
import com.fy.util.db.ann.DbTable;
import lombok.Getter;
import lombok.Setter;
import java.util.Date;

@Getter
@Setter
@DbTable(tableName = "tb_user", primaryKey = "id")
public class UserDto {
    /**
     * 主键
     */
    private String id;

    /**
     * 用户名
     */
    private String userName;

    /**
     * 密码
     */
    private String userPassword;

    /**
     * 密码盐
     */
    private String userSalt;

    /**
     * 手机号
     */
    private String phone;

    /**
     * 创建人
     */
    @DbColumn(defaultValue = "'test-user'")
    private String createUser;

    /**
     * 创建时间
     */
    @DbColumn(ignoredInsert = true)
    private Date createTime;

    /**
     * 更新人
     */
    @DbColumn(defaultValue = "'test-user'")
    private String updateUser;

    /**
     * 更新时间
     */
    @DbColumn(ignoredUpdate = true, ignoredInsert = true)
    private Date updateTime;
}

2、使用 spring 注入,然后调用即可

@Autowired
private DbUtil dbUtil;

public void testDbUtilBatchInsertUser(int count) {
    long t1 = System.currentTimeMillis();
    dbUtil.batchInsert(getUserList(count));
    System.out.println("【DbUtil】插入条数:【" + count + "】耗时:【"
            + (System.currentTimeMillis() - t1) + "】");
}
结语

到此,使用 JdbcTemplate 来 batch 写入数据的实现就介绍完了,后续继续其他方式的批量写入 …

如果您看到了这里,欢迎和我沟通交流!
             一个95后码农

个人博客:fy-blog

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值