pytorch的ddp分布式训练

pytorch的ddp分布式训练

相对于单卡,多卡分布式训练能够:1.处理更大的模型的数据集,2.充分利用实验室硬件资源,3.以及最重要的加速训练速度,目前已经被广泛使用。本文用于简化pytorch的ddp训练代码,并解释每一步的作用。

导包

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
  1. torch.distributed包是用于支持分布式深度学习的核心包。它提供了其中包括初始化分布式环境、同步张量、发送和接收消息等功能函数。
  2. torch.multiprocessing是PyTorch中的多进程包,用于启动和管理多个进程。
  3. DistributedDataParallel是PyTorch的分布式训练工具,用于复制模型到不同GPU上,训练并同步梯度和参数。
  4. torch.utils.data.distributed.DistributedSampler是数据采样器(sampler)类,确保不同GPU或节点上的模型训练使用不同的数据子集,避免数据重复使用。

初始化进程组

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
dist.init_process_group('nccl', rank = rank, world_size = world_size)
  1. MASTER_ADDRMASTER_PORT是两个待设置的环境变量,指定进程主节点的地址和端口

  2. nccl为通信后端,rank为当前分配的进程的标识,world_size表示总进程数

划分和分发数据

sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
batch_size = math.ceil(self.batch_size / self.world_size)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, shuffle=False, pin_memory=True)
  1. sampler的rank和num_replicas分别用于指定当前进程和总参与的进程数
  2. batch_size为设置数/总参与进程数,如设置为8,有4个GPU参与,每个GPU负责2个batch
  3. 因为sampler已经打乱了数据顺序,dataloader的参数shuffle=False

包装为DDP模型

ddp_kwargs = {'device_ids': [rank], 'find_unused_parameters': True}
model_ddp = DDP(model, **ddp_kwargs)
  1. device_ids 参数指定了在哪个GPU上运行DDP模型
  2. find_unused_parameters 参数告诉DDP是否查找未使用的参数。模型中有一部分参数不参与更新则可能会报警告,但如果你确定这个改动,则不需要管它。如GAN等模型训练辨别器时,生成器参数只参与前向传播,不参与更新。

启动多卡训练

mp.spawn(train,args=(world_size, model_args, num_train_steps, name, seed),nprocs=world_size,join=True)
  1. 第一个参数为训练函数train,args为传递给train函数的参数,同时还会隐式传递一个进程标识参数rank给train函数
  2. nprocs参数指定要启动的进程数量,通常等于分布式环境中的总进程数world_size
  3. join参数用于指定是否要等待所有进程完成后再继续执行后续代码。设置为True表示要等待所有进程完成,然后再继续执行主进程的代码

释放资源

dist.destroy_process_group() # 训练结束释放进程组和其他资源

注意点:

  1. model_ddp = DDP(model, **ddp_kwargs)这一步会将模型进一步包装为ddp模型,如果想要调用原来model中的方法需要加.module,如model_ddp.module.func(...)

  2. 同上,参数保存时也会在key前加module,可以在写入时或读取模型时候手动修改

    # 载入参数时修改state_dict的key值
    ck = {k.replace('module.', ''): v for k, v in torch.load('model.pth')['state_dict'].items()}
    model.load_state_dict(ck)
    
  3. 使用梯度累积+多卡训练时,只有更新参数时需要梯度同步。故在梯度累积阶段,使用model_ddp.no_sync()主动不进行梯度同步

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值