JAVA面试题分享三十:多线程事物回滚?

一、背景

日常项目中,经常会出现一个场景,同时批量插入数据库数据,由于逻辑复杂或者其它原因,我们无法使用sql进行批量插入。串行效率低,耗时长,为了提高效率,这个时候我们首先想到多线程并发插入,但是如何控制事务呢 

二、实现效果

开启多条子线程,并发插入数据库

当其中一条线程出现异常,或者处理结果为非预期结果,则全部线程均回滚

@Service
public class CompanyUserBatchServiceImpl implements CompanyUserBatchService {
    private static final Logger logger = LoggerFactory.getLogger(CompanyUserBatchServiceImpl.class);

    @Autowired
    private CompanyUserService companyUserService;

    @Override
    public ReturnData addNewCurrentCompanyUsers(String params) {
        logger.info("addNewCompanyUsers 新增参保人方法");
        logger.info(">>>>>>>>>>>>参数:{}", params);
        ReturnData rd = new ReturnData();
        rd.setRetCode(CommonConstants.RETURN_CODE_FAIL);
        if (StringUtils.isBlank(params)) {
            rd.setMsg("入参为空!");
            logger.info(">>>>>>入参为空。");
            return rd;
        }

        List<CompanyUserResultVo> companyUsers;
        try {
            companyUsers = JSONObject.parseArray(params, CompanyUserResultVo.class);
        } catch (Exception e) {
            logger.info(">>>>>>>>>入参格式有误: {}", e);
            rd.setMsg("入参格式有误!");
            return rd;
        }


        //每条线程最小处理任务数
        int perThreadHandleCount = 1;
        //线程池的最大线程数
        int nThreads = 10;
        int taskSize = companyUsers.size();

        if (taskSize > nThreads * perThreadHandleCount) {
            perThreadHandleCount = taskSize % nThreads == 0 ? taskSize / nThreads : taskSize / nThreads + 1;
            nThreads = taskSize % perThreadHandleCount == 0 ? taskSize / perThreadHandleCount : taskSize / perThreadHandleCount + 1;
        } else {
            nThreads = taskSize;
        }

        logger.info("批量添加参保人taskSize: {}, perThreadHandleCount: {}, nThreads: {}", taskSize, perThreadHandleCount, nThreads);
        CountDownLatch mainLatch = new CountDownLatch(1);
        //监控子线程
        CountDownLatch threadLatch = new CountDownLatch(nThreads);
        //根据子线程执行结果判断是否需要回滚
        BlockingDeque<Boolean> resultList = new LinkedBlockingDeque<>(nThreads);
        //必须要使用对象,如果使用变量会造成线程之间不可共享变量值
        RollBack rollBack = new RollBack(false);
        ExecutorService fixedThreadPool = Executors.newFixedThreadPool(nThreads);

        List<Future<List<Object>>> futures = Lists.newArrayList();
        List<Object> returnDataList = Lists.newArrayList();
        //给每个线程分配任务
        for (int i = 0; i < nThreads; i++) {
            int lastIndex = (i + 1) * perThreadHandleCount;
            List<CompanyUserResultVo> companyUserResultVos = companyUsers.subList(i * perThreadHandleCount, lastIndex >= taskSize ? taskSize : lastIndex);
            AddNewCompanyUserThread addNewCompanyUserThread = new AddNewCompanyUserThread(mainLatch, threadLatch, rollBack, resultList, companyUserResultVos);
            Future<List<Object>> future = fixedThreadPool.submit(addNewCompanyUserThread);
            futures.add(future);
        }

        /** 存放子线程返回结果. */
        List<Boolean> backUpResult = Lists.newArrayList();
        try {
            //等待所有子线程执行完毕
            boolean await = threadLatch.await(20, TimeUnit.SECONDS);
            //如果超时,直接回滚
            if (!await) {
                rollBack.setRollBack(true);
            } else {
                logger.info("创建参保人子线程执行完毕,共 {} 个线程", nThreads);
                //查看执行情况,如果有存在需要回滚的线程,则全部回滚
                for (int i = 0; i < nThreads; i++) {
                    Boolean result = resultList.take();
                    backUpResult.add(result);
                    logger.debug("子线程返回结果result: {}", result);
                    if (result) {
                        /** 有线程执行异常,需要回滚子线程. */
                        rollBack.setRollBack(true);
                    }
                }
            }
        } catch (InterruptedException e) {
            logger.error("等待所有子线程执行完毕时,出现异常");
            throw new SystemException("等待所有子线程执行完毕时,出现异常,整体回滚");
        } finally {
            //子线程再次开始执行
            mainLatch.countDown();
            logger.info("关闭线程池,释放资源");
            fixedThreadPool.shutdown();
        }

        /** 检查子线程是否有异常,有异常整体回滚. */
        for (int i = 0; i < nThreads; i++) {
            if (CollectionUtils.isNotEmpty(backUpResult)) {
                Boolean result = backUpResult.get(i);
                if (result) {
                    logger.info("创建参保人失败,整体回滚");
                    throw new SystemException("创建参保人失败");
                }
            } else {
                logger.info("创建参保人失败,整体回滚");
                throw new SystemException("创建参保人失败");
            }
        }

        //拼接结果
        try {
            for (Future<List<Object>> future : futures) {
                returnDataList.addAll(future.get());
            }
        } catch (Exception e) {
            logger.info("获取子线程操作结果出现异常,创建的参保人列表: {} ,异常信息: {}", JSONObject.toJSONString(companyUsers), e);
            throw new SystemException("创建参保人子线程正常创建参保人成功,主线程出现异常,回滚失败");
        }

        rd.setRetCode(CommonConstants.RETURN_CODE_SUCCESS);
        rd.setData(returnDataList);
        return rd;
    }

    public class AddNewCompanyUserThread implements Callable<List<Object>> {
        /**
         * 主线程监控
         */
        private CountDownLatch mainLatch;
        /**
         * 子线程监控
         */
        private CountDownLatch threadLatch;
        /**
         * 是否回滚
         */
        private RollBack rollBack;
        private BlockingDeque<Boolean> resultList;
        private List<CompanyUserResultVo> taskList;

        public AddNewCompanyUserThread(CountDownLatch mainLatch, CountDownLatch threadLatch, RollBack rollBack, BlockingDeque<Boolean> resultList, List<CompanyUserResultVo> taskList) {
            this.mainLatch = mainLatch;
            this.threadLatch = threadLatch;
            this.rollBack = rollBack;
            this.resultList = resultList;
            this.taskList = taskList;
        }

        @Override
        public List<Object> call() {
            //为了保证事务不提交,此处只能调用一个有事务的方法,spring 中事务的颗粒度是方法,只有方法不退出,事务才不会提交
            return companyUserService.addNewCompanyUsers(mainLatch, threadLatch, rollBack, resultList, taskList);
        }

    }

    public class RollBack {
        private Boolean isRollBack;

        public Boolean getRollBack() {
            return isRollBack;
        }

        public void setRollBack(Boolean rollBack) {
            isRollBack = rollBack;
        }

        public RollBack(Boolean isRollBack) {
            this.isRollBack = isRollBack;
        }
    }

public List<Object> addNewCompanyUsers(CountDownLatch mainLatch, CountDownLatch threadLatch, CompanyUserBatchServiceImpl.RollBack rollBack, BlockingDeque<Boolean> resultList, List<CompanyUserResultVo> taskList) {
        List<Object> returnDataList = Lists.newArrayList();
        Boolean result = false;
        logger.debug("线程: {}创建参保人条数 : {}", Thread.currentThread().getName(), taskList.size());
        try {
            for (CompanyUserResultVo companyUserResultVo : taskList) {
                ReturnData returnData = addSingleCompanyUser(companyUserResultVo);
                if (returnData.getRetCode() == CommonConstants.RETURN_CODE_FAIL) {
                    result = true;
                }
                returnDataList.add(returnData.getData());
            }
            //Exception 和 Error 都需要抓
        } catch (Throwable throwable) {
            throwable.printStackTrace();
            logger.info("线程: {}创建参保人出现异常: {} ", Thread.currentThread().getName(), throwable);
            result = true;
        }

        resultList.add(result);
        threadLatch.countDown();
        logger.info("子线程 {} 计算过程已经结束,等待主线程通知是否需要回滚", Thread.currentThread().getName());

        try {
            mainLatch.await();
            logger.info("子线程 {} 再次启动", Thread.currentThread().getName());
        } catch (InterruptedException e) {
            logger.error("批量创建参保人线程InterruptedException异常");
            throw new SystemException("批量创建参保人线程InterruptedException异常");
        }

        if (rollBack.getRollBack()) {
            logger.error("批量创建参保人线程回滚, 线程: {}, 需要更新的信息taskList: {}",
                    Thread.currentThread().getName(),
                    JSONObject.toJSONString(taskList));
            logger.info("子线程 {} 执行完毕,线程退出", Thread.currentThread().getName());
            throw new SystemException("批量创建参保人线程回滚");
        }

        logger.info("子线程 {} 执行完毕,线程退出", Thread.currentThread().getName());
        return returnDataList;
    }

三、方案

思想就是使用两个CountDownWatch实现子线程的二段提交

步骤:

1、主线程将任务分发给子线程,然后使用 boolean await = threadLatch.await(20,TimeUnit.SECONDS);阻塞主线程,等待所有子线程处理向数据库中插入的业务

2、使用threadLatch.countDown();释放子线程锁定,同时使用mainLatch.await();阻塞子线程,将程序的控制权交还给主线程。

3、主线程检查子线程执行插入数据库的结果,若有非预期结果出现,主线程标记状态告知子线程回滚,然后使用mainLatch.countDown();将程序控制权再次交给子线程,子线程检测回滚标志,判断是否回滚。

4、子线程执行结束,主线程拼接处理结果,响应给请求方

整个过程类似于GC的标记-清除过程(串行的垃圾收集器)

四、思路总结

多线程事务回滚是指在多线程环境下进行数据库操作时,如果其中一个线程的事务执行失败,需要将整个事务回滚到执行前的状态,以保证数据库的一致性和完整性。

在多线程环境下,事务的并发执行可以提高系统的性能和响应速度,但同时也增加了事务失败的风险。如果其中一个线程的事务执行失败,可能会导致整个数据库处于不一致的状态,这时就需要进行事务回滚操作。

事务回滚的实现方式通常是通过数据库管理系统提供的事务控制语句来实现的,例如SQL中的ROLLBACK语句。在多线程环境下,需要在编程中控制事务的提交和回滚,以保证事务的一致性和完整性。

同时,在多线程环境下进行事务回滚时需要注意以下几点:

  1. 事务的边界:需要明确事务的边界,即哪些操作属于同一个事务,以便在出现错误时进行回滚。
  2. 并发控制:在多线程环境下,需要控制不同线程对同一事务的并发访问,避免出现冲突和竞争条件。
  3. 错误处理:在编程时需要考虑到可能出现的错误情况,并相应地进行处理,如回滚事务、记录日志等。
  4. 性能优化:事务回滚可能会影响系统性能,因此在进行多线程事务处理时,需要注意性能优化,如减少不必要的数据库访问、优化事务处理等。

总之,多线程事务回滚是保障数据库一致性和完整性的重要手段,需要在编程时进行合理的控制和处理。

  • 11
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

之乎者也·

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值