【分布式训练(4)】accelerator.sync_gradients 和 checkpointing 深入理解

【分布式训练 debug】VS Code Debug 技巧:launch.json实用参数
【分布式训练(2)】深入理解 DeepSpeed 的 ZeRO 内存优化策略 (三阶段的区别)
【分布式训练(3)】accelerator + deepspeed debug 报错 “Timed out waiting for debuggee to spawn“ 解决方法✅


accelerator.sync_gradients

sync_gradients(同步梯度)

  • sync_gradients 是一个在分布式训练中使用的策略,它涉及到在多个训练节点(或GPU)之间同步梯度。
  • 在分布式训练中,每个节点计算其自己的梯度(即损失函数对模型参数的偏导数),然后这些梯度需要被聚合以更新模型的全局参数。
  • sync_gradients 通常在每个优化步骤后执行,以确保所有节点上的模型参数保持一致。
# Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0
  • 在该代码片段中,accelerator.sync_gradients 可能是一个标志(flag),指示是否需要执行梯度同步。
  • 如果是这样,那么在每次优化步骤后,代码会更新进度条,记录日志,并可能执行其他清理或记录操作。
checkpointing(检查点)
  • Checkpointing 是一种保存训练过程中的关键状态的机制,以便在发生故障或为了恢复训练时可以从这些点重新开始。
  • 在深度学习中,检查点通常包括模型的参数(权重和偏置),优化器的状态(如动量项),以及可能的损失值和当前的迭代次数。
if global_step % args.checkpointing_steps == 0:
   if accelerator.is_main_process:
       # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
       if args.checkpoints_total_limit is not None:
           checkpoints = os.listdir(args.output_dir)
           checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
           checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

           # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
           if len(checkpoints) >= args.checkpoints_total_limit:
               num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
               removing_checkpoints = checkpoints[0:num_to_remove]

               logger.info(
                   f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
               )
               logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

               for removing_checkpoint in removing_checkpoints:
                   removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                   shutil.rmtree(removing_checkpoint)
  • 在该代码中,checkpointing_steps 是一个参数,指定每多少步进行一次检查点保存。当达到这个步数时,如果是主进程(accelerator.is_main_process),代码会执行以下操作:
  1. 检查是否达到了检查点总数的限制(checkpoints_total_limit)。
  2. 如果超过了限制,删除最旧的检查点,以确保不会超过最大限制。
  3. 创建一个新的检查点目录,并保存模型权重和优化器状态。
  4. 记录保存检查点的信息。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值