Pytorch分布式训练杂记1

  1. dist.broadcast_object_list

    • 功能:这个函数用于在分布式训练中,将某个进程上的对象列表广播给所有其他进程。它可以确保所有进程在训练开始前或训练过程中共享相同的数据,尤其是在多进程场景中每个进程都需要访问相同的对象时。
    • 用法
      import torch.distributed as dist
      # 假设我们有一个对象列表 obj_list
      obj_list = [some_object]
      # 从 rank 0 广播对象列表给所有其他进程
      dist.broadcast_object_list(obj_list, src=0)
      
      • obj_list:要广播的对象列表。
      • src:源进程的 rank(一般是 0),该进程上的对象会广播给其他进程。
  2. dist.barrier()

    • 功能:这个函数是一个同步机制。它会让所有进程在调用此函数后,必须等待其他进程也到达这个屏障点才能继续执行。通常用于确保所有进程都完成了某些步骤(比如数据加载)后再继续训练,避免进程之间的不同步问题。
    • 用法
      import torch.distributed as dist
      # 阻塞所有进程,直到每个进程都到达这个屏障
      dist.barrier()
      
  3. dist.get_rank()

    • 功能:这个函数返回当前进程的 rank,即进程在分布式训练中的编号。每个进程在初始化时都会被分配一个唯一的 rank,rank 为 0 的进程通常被称为主进程(主节点),负责某些全局操作,比如保存模型。
    • 用法
      import torch.distributed as dist
      # 获取当前进程的 rank
      rank = dist.get_rank()
      print(f"Current rank: {rank}")
      
  4. ddp = int(os.environ.get("RANK", -1)) != -1

    • 功能:这一行代码通过检查环境变量 RANK 来判断当前是否处于分布式环境中。RANK 环境变量在分布式训练中通常由框架自动设置,表示进程的编号。如果 RANK 存在且不是 -1,则表明当前进程正在参与分布式训练。
    • 解释
      import os
      ddp = int(os.environ.get("RANK", -1)) != -1
      if ddp:
          print("We are in a distributed environment.")
      else:
          print("Not in a distributed environment.")
      

    这段代码通过检查 RANK 是否存在(并且不是 -1)来判断是否处于分布式训练模式下,ddp 变量的值为 True 表示是分布式环境,False 表示不是。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值