Spring Data JPA Batch Insert Too Slow and Its Optimization: Extracting a Shared Generic batchSave Method and Introducing Multithreading

The previous post only handled batch inserts for a single entity type. For details, see:

Spring Data JPA Batch Insert Too Slow and Its Optimization: Custom Repository (https://blog.csdn.net/tfstone/article/details/113741890)

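What we need now is a shared batchSave method, and generics are the natural way to get there. Only a few changes to the previous version are needed: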

I. Add the batch import interface BatchSaveRepository

The isSave parameter: true means save (insert), false means update.

package com.easemob.oa.persistence.jpa;

import org.springframework.data.repository.NoRepositoryBean;
import java.util.List;

@NoRepositoryBean
public interface BatchSaveRepository<T> {

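    /**
     * Saves or updates the entities in batches and returns them (after an insert, with generated ids).
     *
     * @param isSave true - insert (persist), false - update (merge)
     */
    <S extends T> List<S> batchSave(Iterable<S> entities, Boolean isSave);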

}

II. Add the implementation class BatchSaveRepositoryImpl

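Note that EntityManager is not thread-safe, so it needs care when used from multiple threads.

So how does each thread get an EntityManager (em for short) that it can use safely?

Since EntityManagerFactory (emf for short) is thread-safe, each thread can simply obtain its own em from the emf when it starts.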
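A minimal, JDK-only sketch of this idea follows. UnsafeSession and SessionFactory are hypothetical stand-ins for EntityManager and EntityManagerFactory (no JPA is involved); the point is only that the shared object is the factory, while each task opens, uses, and closes its own session.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class PerThreadSessionDemo {
    // Each task opens its own session from the shared factory; no session is shared between threads
    static List<Integer> run(int tasks, int perTask) {
        SessionFactory factory = new SessionFactory();
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            List<Future<Integer>> futures = new ArrayList<>();
            for (int t = 0; t < tasks; t++) {
                futures.add(pool.submit(() -> {
                    // try-with-resources closes the session when the task ends
                    try (UnsafeSession session = factory.createSession()) {
                        for (int i = 0; i < perTask; i++) {
                            session.persist("entity-" + i);
                        }
                        return session.pendingCount();
                    }
                }));
            }
            List<Integer> counts = new ArrayList<>();
            for (Future<Integer> f : futures) {
                counts.add(f.get());
            }
            return counts;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(run(3, 5)); // prints [5, 5, 5]
    }
}

// Hypothetical stand-in for EntityManager: holds mutable state and is NOT thread-safe
class UnsafeSession implements AutoCloseable {
    private final List<String> pending = new ArrayList<>(); // plays the role of the persistence context

    void persist(String entity) {
        pending.add(entity);
    }

    int pendingCount() {
        return pending.size();
    }

    @Override
    public void close() {
        pending.clear();
    }
}

// Hypothetical stand-in for EntityManagerFactory: stateless, so safe to share across threads
class SessionFactory {
    UnsafeSession createSession() {
        return new UnsafeSession();
    }
}
```

Because every task only ever touches the session it created, each one sees exactly its own five entities.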

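So how do we obtain the EntityManagerFactory?

Typically, an EntityManagerFactory is obtained dynamically by looking up the persistence-unit-name specified in a configuration file.

There are two ways to specify the persistence-unit-name: configure it in persistence.xml, or configure it in JpaConfig.java.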

META-INF/persistence.xml

<?xml version="1.0" encoding="UTF-8"?>
<persistence xmlns="http://java.sun.com/xml/ns/persistence" version="2.0">
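    <!-- A persistence-unit node is required.
        Persistence unit:
            name: the persistence unit name
            transaction-type: how transactions are managed
                    JTA: distributed transaction management (used when tables are spread across different databases)
                    RESOURCE_LOCAL: local transaction management
    -->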
    <persistence-unit name="turnfly" transaction-type="RESOURCE_LOCAL">
        <!-- JPA implementation (provider) -->
        <provider>org.hibernate.jpa.HibernatePersistenceProvider</provider>

        <!-- Optional: configuration for the JPA implementation -->
        <properties>
            <!-- Database connection info
                user name:  javax.persistence.jdbc.user
                password:   javax.persistence.jdbc.password
                driver:     javax.persistence.jdbc.driver
                URL:        javax.persistence.jdbc.url
            -->
            <property name="javax.persistence.jdbc.user" value="oatransfer"/>
            <property name="javax.persistence.jdbc.password" value="qwert"/>
            <property name="javax.persistence.jdbc.driver" value="com.oscar.Driver"/>
            <property name="javax.persistence.jdbc.url" value="jdbc:oscar://x.x.x.x:2003/OSRDB?useSSL=false"/>

            <!-- Settings for the JPA implementation (Hibernate)
                show SQL                    : hibernate.show_sql, false|true
                auto-create database tables : hibernate.hbm2ddl.auto
                        create : create tables at startup (existing tables are dropped first)
                        update : create tables at startup (existing tables are kept, not recreated)
                        none   : do not create tables
            -->
            <!-- Show SQL -->
            <property name="hibernate.show_sql" value="true" />
            <property name="hibernate.dialect" value="org.hibernate.dialect.OracleDialect" />
            <!-- Auto-create database tables -->
            <property name="hibernate.hbm2ddl.auto" value="update" />
        </properties>
    </persistence-unit>
</persistence>

JpaConfig.java

package com.easemob.oa.persistence.config;

import org.springframework.boot.autoconfigure.jdbc.DataSourceBuilder;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.datasource.DriverManagerDataSource;
import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean;
import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter;

import javax.sql.DataSource;
import java.util.Properties;

@Configuration
public class JpaConfig {

    /**
     * DataSource Config
     */
    @Bean
    public DataSource getDataSource() {
        DriverManagerDataSource dataSource = new DriverManagerDataSource();
        dataSource.setDriverClassName("com.oscar.Driver");
        dataSource.setUrl("jdbc:oscar://127.0.0.1:2003/OSRDB_TF?useSSL=false&rewriteBatchedStatements=TRUE");
        dataSource.setUsername("oatransfer");
        dataSource.setPassword("0098");
        return dataSource;
    }

    @Bean
    @ConfigurationProperties(prefix = "oa-server.second-datasource")
    public DataSource getSecondDataSource() {
        return DataSourceBuilder.create().build();
    }

    @Bean
    public LocalContainerEntityManagerFactoryBean entityManagerFactory() {
        // Hibernate as the JPA vendor; once initialized, related tables are created/updated
        // according to hibernate.hbm2ddl.auto below
        HibernateJpaVendorAdapter vendorAdapter = new HibernateJpaVendorAdapter();
        LocalContainerEntityManagerFactoryBean factoryBean = new LocalContainerEntityManagerFactoryBean();
        factoryBean.setJpaVendorAdapter(vendorAdapter);

        Properties properties = new Properties();
        properties.setProperty("hibernate.dialect", "org.hibernate.dialect.MySQLDialect");
        properties.setProperty("hibernate.show_sql", "true");
        properties.setProperty("hibernate.hbm2ddl.auto", "update");
        properties.setProperty("hibernate.format_sql", "false");
        factoryBean.setJpaProperties(properties);

        // Package containing the entity classes
        factoryBean.setPackagesToScan("com.easemob.oa.models.entity");
        // Use the second data source
        factoryBean.setDataSource(getSecondDataSource());
        // Persistence unit name
        factoryBean.setPersistenceUnitName("turnfly");
        return factoryBean;
    }
}

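Spring Boot spares us this configuration: at startup it automatically creates a persistence unit named "default".

However, if we try to obtain the factory for this default name the configuration-file way,

emf = Persistence.createEntityManagerFactory("default");

we get nothing useful: the call fails, because no persistence.xml on the classpath declares a unit with that name.

On reflection this makes sense. We never wrote a persistence.xml, so looking the persistence-unit-name up from a configuration file cannot work. Spring Boot has already registered an EntityManagerFactory bean for the "default" persistence unit in its singleton pool, so we can simply inject that one.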

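The correct approach is to add a helper class annotated with @Repository, so that component scanning picks it up when the application starts, and to inject the EntityManagerFactory with the @PersistenceUnit annotation.

For more on EntityManager and EntityManagerFactory, see my other article: EntityManager and EntityManagerFactory Explained

1. Create the helper class EntityManagerHelper

A ThreadLocal keeps each thread's EntityManager private to that thread, and emf is assigned through setter injection.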
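The per-thread caching pattern the helper relies on can be sketched with the JDK alone. Resource below is a hypothetical stand-in for EntityManager, and get()/close() mirror getEntityManager()/closeEntityManager(). Calling ThreadLocal.remove() rather than set(null) is a deliberate choice in this sketch: pool threads are reused, and remove() drops the entry entirely.

```java
class ThreadLocalHolderDemo {
    // Hypothetical stand-in for EntityManager
    static final class Resource {
        boolean open = true;
    }

    private static final ThreadLocal<Resource> HOLDER = new ThreadLocal<>();

    // Mirrors getEntityManager(): reuse this thread's open resource, or create and store a new one
    static Resource get() {
        Resource r = HOLDER.get();
        if (r == null || !r.open) {
            r = new Resource();
            HOLDER.set(r);
        }
        return r;
    }

    // Mirrors closeEntityManager(): drop the ThreadLocal entry, then close the resource
    static void close() {
        Resource r = HOLDER.get();
        HOLDER.remove();
        if (r != null) {
            r.open = false;
        }
    }

    // Runs get() on a separate thread and returns what that thread received
    static Resource getOnNewThread() {
        Resource[] box = new Resource[1];
        Thread t = new Thread(() -> {
            box[0] = get();
            close();
        });
        t.start();
        try {
            t.join(); // join() also makes box[0] visible to this thread
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return box[0];
    }

    public static void main(String[] args) {
        Resource a = get();
        System.out.println(a == get());            // prints true: same thread, same resource
        System.out.println(a == getOnNewThread()); // prints false: another thread gets its own
        close();
        System.out.println(a == get());            // prints false: a fresh one after close()
        close();
    }
}
```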

package com.easemob.oa.persistence.config;

import org.springframework.stereotype.Repository;

import javax.persistence.EntityManager;
import javax.persistence.EntityManagerFactory;
import javax.persistence.PersistenceUnit;
import javax.persistence.Query;

@Repository
public class EntityManagerHelper {
    // Shared, thread-safe EntityManagerFactory injected by Spring
    // (instead of Persistence.createEntityManagerFactory(...), which needs a persistence.xml)
    private static EntityManagerFactory emf;

    // Holds each thread's own EntityManager
    private static final ThreadLocal<EntityManager> threadLocal = new ThreadLocal<>();

    // Setter injection: Spring calls this at startup, and the factory is kept in the static field
    @PersistenceUnit
    public void setEntityManagerFactory(EntityManagerFactory emf) {
        EntityManagerHelper.emf = emf;
    }

    // Return the current thread's EntityManager, creating one if needed
    public static EntityManager getEntityManager() {
        EntityManager manager = threadLocal.get();
        if (manager == null || !manager.isOpen()) {
            manager = emf.createEntityManager();
            threadLocal.set(manager);
        }
        return manager;
    }

    // Close the current thread's EntityManager and clear the ThreadLocal entry
    // (remove() matters with thread pools, where threads are reused)
    public static void closeEntityManager() {
        EntityManager em = threadLocal.get();
        threadLocal.remove();
        if (em != null && em.isOpen()) {
            em.close();
        }
    }

    // Begin a transaction
    public static void beginTransaction() {
        getEntityManager().getTransaction().begin();
    }

    // Commit the transaction
    public static void commitTransaction() {
        getEntityManager().getTransaction().commit();
    }

    // Roll back the transaction
    public static void rollback() {
        getEntityManager().getTransaction().rollback();
    }

    // Create a JPQL query
    public static Query createQuery(String query) {
        return getEntityManager().createQuery(query);
    }
}

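2. BatchSaveRepositoryImpl.java and CallableResultVo.java

Extending SimpleJpaRepository is the usual way to get the em injected; the Spring documentation describes it in the section on custom repository implementations. In the multithreaded version each thread obtains its own em from the emf, so extending SimpleJpaRepository is not strictly necessary (you can remove it in your own code).

In addition, to get back each sublist, with its ids filled in, once its thread finishes, the worker should implement the Callable interface; a CallableResultVo object wraps the returned list.

A CountDownLatch ensures that the next step runs only after all threads have finished.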
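A small JDK-only sketch of the CountDownLatch handshake, assuming a hypothetical workload in which some workers fail (LatchDemo and runAll are illustrative names). The key detail is that countDown() sits in a finally block: if a worker counted down only on its success path, a single exception would leave await() blocked forever. The bounded await(timeout) is an extra safety net of this sketch.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class LatchDemo {
    // Runs `tasks` workers; every failEvery-th worker throws (failEvery <= 0 means none fail).
    // Returns how many workers succeeded.
    static int runAll(int tasks, int failEvery) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        CountDownLatch latch = new CountDownLatch(tasks);
        AtomicInteger succeeded = new AtomicInteger();
        try {
            for (int i = 1; i <= tasks; i++) {
                final int taskNo = i;
                // submit() captures a worker's exception in its Future instead of printing it
                pool.submit(() -> {
                    try {
                        if (failEvery > 0 && taskNo % failEvery == 0) {
                            throw new IllegalStateException("simulated failure in task " + taskNo);
                        }
                        succeeded.incrementAndGet();
                    } finally {
                        // Always count down, even when the work above throws
                        latch.countDown();
                    }
                });
            }
            // Bounded wait as an extra safety net
            if (!latch.await(10, TimeUnit.SECONDS)) {
                throw new IllegalStateException("workers did not finish in time");
            }
            return succeeded.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(runAll(10, 3)); // prints 7 (tasks 3, 6 and 9 fail)
    }
}
```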

CallableResultVo.java

package com.easemob.oa.models.callable;

import lombok.Data;

import java.util.List;

@Data
public class CallableResultVo<T> {
    // Entities processed by one worker thread
    private List<T> result;
}

BatchSaveRepositoryImpl.java 

package com.easemob.oa.persistence.jpa.impl;

import com.easemob.oa.models.callable.CallableResultVo;
import com.easemob.oa.persistence.jpa.BatchSaveRepository;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.repository.NoRepositoryBean;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import javax.persistence.EntityManager;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;


/**
 * @Author turnflys
 * @Date 1/12/21 10:51 PM
 */
@NoRepositoryBean
@Slf4j
public class BatchSaveRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID> implements BatchSaveRepository<T> {

    // Number of records handled by each thread
    private static final int BATCH_SIZE = 1500;

    // Maximum number of threads
    private static final int MAX_THREAD = 12;

    // Flush/clear interval used by the single-threaded partBatchSave
    private static final int FLUSH_SIZE = 500;

    // EntityManager injected by Spring; an instance field (not static),
    // so repositories for different entity types do not overwrite each other's
    private final EntityManager em;

    public BatchSaveRepositoryImpl(JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {
        super(entityInformation, entityManager);
        this.em = entityManager;
    }


    // This transaction is bound to the calling thread only;
    // each worker thread runs its own transaction inside PartSaveCallable
    //@Async
    @Override
    @Transactional(propagation = Propagation.REQUIRES_NEW)
    public <S extends T> List<S> batchSave(Iterable<S> entities, Boolean isSave) {
        // Returned list; the entities carry their ids
        List<S> result = new ArrayList<>();

        List<S> lists = Lists.newArrayList(entities);
        int listSize = lists.size();
        if (listSize == 0) {
            return result;
        }

        // Thread pool: at most MAX_THREAD threads at once, each handling up to BATCH_SIZE records
        ExecutorService executorService = Executors.newFixedThreadPool(MAX_THREAD);
        // Number of chunks (ceiling division); the last chunk is usually smaller than BATCH_SIZE
        int loopCount = (listSize + BATCH_SIZE - 1) / BATCH_SIZE;
        // Countdown latch: await() blocks until countDown() has been called loopCount times
        CountDownLatch cdl = new CountDownLatch(loopCount);
        // One Future per thread, holding that thread's result
        List<Future<CallableResultVo<S>>> futureList = new ArrayList<>();

        try {
            for (int i = 0; i < loopCount; i++) {
                int start = i * BATCH_SIZE;
                int end = Math.min(start + BATCH_SIZE, listSize);
                log.info("Split data index range: start - {}, end - {}.", start, end);
                List<S> listsTmp = lists.subList(start, end);
                futureList.add(executorService.submit(new PartSaveCallable<>(listsTmp, cdl, isSave)));
            }

            // Make sure all threads have finished
            cdl.await();
            for (Future<CallableResultVo<S>> future : futureList) {
                try {
                    // Every worker has counted down by now; get() returns its result (blocking briefly if needed)
                    CallableResultVo<S> crv = future.get();
                    if (crv != null && crv.getResult() != null) {
                        result.addAll(crv.getResult());
                    }
                } catch (ExecutionException e) {
                    log.error("Batch chunk failed", e.getCause());
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted while waiting for batch threads", e);
        } finally {
            // Shut down the thread pool when done
            executorService.shutdown();
        }

        return result;
    }


    // Single-threaded variant (not used by batchSave above): persist with the injected em,
    // flushing and clearing every FLUSH_SIZE entities to keep the persistence context small
    private <S> List<S> partBatchSave(Iterable<S> entities) {
        List<S> lists = Lists.newArrayList();
        int index = 0;
        for (S entity : entities) {
            em.persist(entity);
            lists.add(entity);
            index++;
            if (index % FLUSH_SIZE == 0) {
                em.flush();
                em.clear();
            }
        }
        if (index % FLUSH_SIZE != 0) {
            em.flush();
            em.clear();
        }
        return lists;
    }


}
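The chunk count deserves a closer look. Computing it as listSize / BATCH_SIZE + 1 yields one extra, empty chunk whenever the size is an exact multiple of BATCH_SIZE (3000 records with BATCH_SIZE = 1500 give 3 chunks, the last one empty), which also means one extra task submitted to the pool. Ceiling division avoids that. A minimal JDK-only sketch of the partitioning step, where ChunkDemo.chunk is a hypothetical helper name:

```java
import java.util.ArrayList;
import java.util.List;

class ChunkDemo {
    // Hypothetical helper: splits items into consecutive sublists of at most batchSize elements
    static <T> List<List<T>> chunk(List<T> items, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        int size = items.size();
        // Ceiling division: 3000 / 1500 -> 2 chunks, 3001 / 1500 -> 3; never an empty trailing chunk
        int chunkCount = (size + batchSize - 1) / batchSize;
        List<List<T>> chunks = new ArrayList<>(chunkCount);
        for (int i = 0; i < chunkCount; i++) {
            int start = i * batchSize;
            int end = Math.min(start + batchSize, size);
            // subList returns a view backed by items, so items must not be modified while chunks are in use
            chunks.add(items.subList(start, end));
        }
        return chunks;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 3000; i++) {
            data.add(i);
        }
        System.out.println(chunk(data, 1500).size()); // prints 2; 3000 / 1500 + 1 would give 3
    }
}
```

Because subList only creates views, the worker threads share the backing list read-only; that is safe as long as nothing modifies the list until all chunks have been processed.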

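III. The worker class PartSaveCallable

Note the difference between Callable and Runnable:

A Callable returns a value when the task finishes and can throw checked exceptions up to the caller; Runnable.run() returns nothing and cannot throw checked exceptions.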
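A JDK-only sketch of the difference, using hypothetical method names: sumTask builds a Callable that returns a value or throws a checked exception, which the caller sees as an ExecutionException from Future.get(); a Runnable submitted to the same kind of executor produces a Future whose get() returns null.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class CallableVsRunnableDemo {
    // Callable: call() returns a value and is allowed to throw a checked exception
    static Callable<Integer> sumTask(int[] values) {
        return () -> {
            if (values.length == 0) {
                throw new Exception("nothing to sum");
            }
            int sum = 0;
            for (int v : values) {
                sum += v;
            }
            return sum;
        };
    }

    // The task's exception reaches the caller wrapped in an ExecutionException from Future.get()
    static String runAndDescribe(int[] values) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Future<Integer> future = pool.submit(sumTask(values));
            return "result=" + future.get();
        } catch (ExecutionException e) {
            return "failed: " + e.getCause().getMessage();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "interrupted";
        } finally {
            pool.shutdown();
        }
    }

    // Runnable: run() returns nothing and cannot throw checked exceptions; its Future yields null
    static boolean runnableResultIsNull() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Runnable task = () -> System.out.println("side effect only");
            return pool.submit(task).get() == null;
        } catch (ExecutionException e) {
            return false;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(runAndDescribe(new int[] {1, 2, 3})); // prints result=6
        System.out.println(runAndDescribe(new int[0]));          // prints failed: nothing to sum
    }
}
```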

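Each thread must run in its own transaction. Note that annotating call() with @Transactional (with or without propagation = Propagation.REQUIRES_NEW) has no effect here: PartSaveCallable is created with new rather than obtained from the Spring container, so no transactional proxy wraps it. The per-thread transaction is therefore begun, committed, and rolled back manually through the thread's own EntityTransaction.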

package com.easemob.oa.persistence.jpa.impl;

import com.easemob.oa.models.callable.CallableResultVo;
import com.easemob.oa.persistence.config.EntityManagerHelper;
import lombok.extern.slf4j.Slf4j;

import javax.persistence.EntityManager;
import javax.persistence.EntityTransaction;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;


@Slf4j
public class PartSaveCallable<S> implements Callable<CallableResultVo<S>> {

    private final List<S> lists;
    private final CountDownLatch cdl;
    private final Boolean isSave;

    public PartSaveCallable(List<S> lists, CountDownLatch cdl, Boolean isSave) {
        this.lists = lists;
        this.cdl = cdl;
        this.isSave = isSave;
    }

    // No @Transactional: this object is created with new, so Spring never proxies it.
    // The transaction is managed manually through this thread's EntityTransaction.
    @Override
    public CallableResultVo<S> call() {
        CallableResultVo<S> result = new CallableResultVo<>();
        List<S> resultList = new ArrayList<>();
        String su = isSave ? "insert" : "update";
        log.info("Thread {} starting {} operation, current cdl count: {}", Thread.currentThread(), su, cdl.getCount());

        EntityTransaction entityTransaction = null;
        try {
            // This thread's own EntityManager
            EntityManager em = EntityManagerHelper.getEntityManager();
            log.info("Current EntityManager: {}", em);
            entityTransaction = em.getTransaction();
            entityTransaction.begin();
            for (S s : lists) {
                if (isSave) {
                    em.persist(s);
                    resultList.add(s);
                } else {
                    // merge() returns the managed copy, so collect that rather than the argument
                    resultList.add(em.merge(s));
                }
            }
            entityTransaction.commit();
            result.setResult(resultList);
            log.info("Thread {} finished {} operation", Thread.currentThread(), su);
        } catch (RuntimeException e) {
            if (entityTransaction != null && entityTransaction.isActive()) {
                entityTransaction.rollback();
            }
            log.error("Error during batch {}", su, e);
        } finally {
            // Always release the EntityManager and count down, even on failure;
            // otherwise cdl.await() in the caller would block forever
            EntityManagerHelper.closeEntityManager();
            cdl.countDown();
        }
        return result;
    }
}

Result: (screenshot of the run output not preserved)
