大量数据插入MySQL的工具类,需配合MybatisPlus注解使用

package com.example.mp.util;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DatePattern;
import cn.hutool.core.date.DateTime;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.lang.Snowflake;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.db.DbUtil;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.extern.slf4j.Slf4j;

import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.*;

/**
 * 用于操作大量数据的MybatisPlus对象,详情看方法注释
 *
 * @author 王刚
 * @since 2022年03月01日 17:39
 */
@Slf4j
public class BigDataSqlUtil {

    /**
     * 根据传入的 Collection 自动生成sql并进行批量更新
     * 使用 传入的Collection泛型中的 TableName注解 作为表名,如果没有TableName注解,则不执行。TableId注解或TableField注解作为字段名
     * platform-dao 模块的test模块下的 CodeGenerator 可以生成需要的泛型,其他的不建议使用
     *
     * @param collection 集合
     * @param dataSource 数据源
     * @param <T>        传入的 Collection 泛型,
     */
    public static <T> long saveBatchUseSqlStr(Collection<T> collection, DataSource dataSource) {
        return saveBatchUseSqlStr(collection, dataSource, 1000, null, false, false);
    }

    /**
     * 根据传入的 Collection 自动生成sql并进行批量更新
     * 使用 传入的Collection泛型中的 TableName注解 作为表名,如果没有TableName注解,则不执行。TableId注解或TableField注解作为字段名
     * platform-dao 模块的test模块下的 CodeGenerator 可以生成需要的泛型,其他的不建议使用
     *
     * @param collection            集合
     * @param dataSource            数据源
     * @param closeForeignKeyChecks 是否在插入时关闭外键检查
     * @param <T>                   传入的 Collection 泛型,
     */
    public static <T> long saveBatchUseSqlStr(Collection<T> collection, DataSource dataSource, boolean closeForeignKeyChecks) {
        return saveBatchUseSqlStr(collection, dataSource, 1000, null, false, closeForeignKeyChecks);
    }

    /**
     * 根据传入的 Collection 自动生成sql并进行批量更新
     * 使用 传入的Collection泛型中的 TableName注解 作为表名,如果没有TableName注解,则不执行。TableId注解或TableField注解作为字段名
     * platform-dao 模块的test模块下的 CodeGenerator 可以生成需要的泛型,其他的不建议使用
     *
     * @param collection 集合
     * @param dataSource 数据源
     * @param batchSize  每批更新的条数
     * @param <T>        传入的 Collection 泛型,
     */
    public static <T> long saveBatchUseSqlStr(Collection<T> collection, DataSource dataSource, int batchSize) {
        return saveBatchUseSqlStr(collection, dataSource, batchSize, null, false, false);
    }

    /**
     * 根据传入的 Collection 自动生成sql并进行批量更新
     * 使用 传入的Collection泛型中的 TableName注解 作为表名,如果没有TableName注解,则不执行。TableId注解或TableField注解作为字段名
     * platform-dao 模块的test模块下的 CodeGenerator 可以生成需要的泛型,其他的不建议使用
     *
     * @param collection            集合
     * @param dataSource            数据源
     * @param batchSize             每批更新的条数
     * @param idType                只有在Collection泛型中配置了TableId注解才会生效
     * @param toUnderlineCase       是否把驼峰转为下划线。只有在Collection泛型的 TableId注解和TableField注解都为空的情况下才会使用
     * @param closeForeignKeyChecks 是否在插入时关闭外键检查
     * @param <T>                   传入的 Collection 泛型,
     */
    public static <T> long saveBatchUseSqlStr(Collection<T> collection, DataSource dataSource, int batchSize, IdType idType, boolean toUnderlineCase, boolean closeForeignKeyChecks) {
        if (CollUtil.isEmpty(collection)) {
            return 0;
        }

        Class<?> beanClass = collection.iterator().next().getClass();

        TableName tableNameAnno = beanClass.getAnnotation(TableName.class);
        if (tableNameAnno == null) {
            return 0;
        }
        String tableName = tableNameAnno.value();

        // 尽量使用SQL拼接,效率更高
        Connection conn = null;
        PreparedStatement ps = null;
        try {
            // 获取数据库连接
            conn = dataSource.getConnection();
            // 设置不自动提交
            conn.setAutoCommit(false);
            // 设置关闭外键检查
            if (closeForeignKeyChecks) {
                PreparedStatement statement = conn.prepareStatement("SET FOREIGN_KEY_CHECKS = 0;");
                statement.execute();
            }

            long updateCount = 0;
            Set<Integer> jumpSet = new HashSet<>();
            Set<Integer> typeHandlerValueIndex = new HashSet<>();
            Map<Integer, IdType> idTypeMap = new HashMap<>();
            StringBuilder columnsSql = new StringBuilder().append("(");

            Field[] fields = ReflectUtil.getFields(beanClass);
            for (int i = 0; i < fields.length; i++) {
                Field field = fields[i];
                if ("serialVersionUID".equalsIgnoreCase(field.getName())) {
                    jumpSet.add(i);
                    continue;
                }
                String columnName = "";
                TableField tableFieldAnno = field.getAnnotation(TableField.class);
                TableId tableIdAnno = field.getAnnotation(TableId.class);

                if (tableFieldAnno == null || tableIdAnno == null) {
                    if (toUnderlineCase) {
                        columnName = StrUtil.toUnderlineCase(field.getName());
                    } else {
                        columnName = field.getName();
                    }
                }

                if (tableFieldAnno != null) {
                    // 如果配置了typeHandler并且开启了自动构建Map,就标记该下标位置的值为一个对象中的ID
                    if (tableFieldAnno.typeHandler() != null && tableNameAnno.autoResultMap()) {
                        typeHandlerValueIndex.add(i);
                    }
                    columnName = tableFieldAnno.value();
                }

                if (tableIdAnno != null) {
                    idTypeMap.put(i, idType != null ? idType : tableIdAnno.type());
                    // 主键自增跳过该字段
                    if (IdType.AUTO.equals(idType)) {
                        jumpSet.add(i);
                        continue;
                    }
                    columnName = tableIdAnno.value();
                }

                if (StrUtil.isBlank(columnName)) {
                    jumpSet.add(i);
                    continue;
                }

                boolean contains = columnName.contains("`");
                if (!contains) {
                    columnsSql.append("`");
                }

                columnsSql.append(columnName);

                if (!contains) {
                    columnsSql.append("`");
                }

                if (i != fields.length - 1) {
                    columnsSql.append(", ");
                } else {
                    columnsSql.append(")");
                }
            }

            Snowflake snowflake = IdUtil.getSnowflake();
            // 这里我们自己实现分批查询,可以掌握进度
            List<List<T>> split = CollUtil.split(collection, batchSize);
            for (List<T> list : split) {
                StringBuilder sql = new StringBuilder("insert into `").append(tableName).append("` ").append(columnsSql).append(" values ");
                for (T t : list) {
                    StringBuilder valuesSql = new StringBuilder().append("(");
                    for (int i = 0; i < fields.length; i++) {
                        if (jumpSet.contains(i)) {
                            continue;
                        }

                        Field field = fields[i];

                        Object value;
                        IdType type = idTypeMap.get(i);
                        if (type != null) {
                            if (IdType.ASSIGN_ID.equals(type)) {
                                value = snowflake.nextIdStr();
                                ReflectUtil.setFieldValue(t, field, value);
                            } else if (IdType.ASSIGN_UUID.equals(type)) {
                                value = IdUtil.fastSimpleUUID();
                                ReflectUtil.setFieldValue(t, field, value);
                            } else {
                                value = ReflectUtil.getFieldValue(t, field);
                            }
                        } else {
                            value = ReflectUtil.getFieldValue(t, field);

                            // 如果typeHandlerValueIndex中包含当前下标,就取对象中的ID值
                            if (typeHandlerValueIndex.contains(i)) {
                                value = ReflectUtil.getFieldValue(value, "id");
                            }
                        }

                        if (value instanceof Date) {
                            value = DateUtil.format((Date) value, DatePattern.NORM_DATETIME_PATTERN);
                        }

                        if (StrUtil.isBlankIfStr(value)) {
                            valuesSql.append("null");
                        } else if (value instanceof Boolean) {
                            valuesSql.append(value);
                        } else {
                            valuesSql.append("'").append(value.toString()).append("'");
                        }

                        if (i != fields.length - 1) {
                            valuesSql.append(", ");
                        } else {
                            valuesSql.append("), ");
                        }
                    }
                    sql.append(valuesSql);
                }

                long start = System.currentTimeMillis();
                String executeSql = sql.substring(0, sql.lastIndexOf(","));
                ps = conn.prepareStatement(executeSql);
                int executeUpdate = ps.executeUpdate();
                log.debug(StrUtil.format("\n Consume Time:{} ms {}\nExecute SQL:{}\n\n", System.currentTimeMillis() - start, DateTime.now().toStringDefaultTimeZone(), executeSql));
                updateCount += executeUpdate;
            }
            conn.commit();

            return updateCount;
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            // 设置开启外键检查
            if (closeForeignKeyChecks) {
                try {
                    PreparedStatement statement = conn.prepareStatement("SET FOREIGN_KEY_CHECKS = 1;");
                    statement.execute();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
            }
            // 关闭连接
            DbUtil.close(ps, conn);
        }
    }

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值