Pytorch分布式训练

本文介绍了PyTorch中用于多GPU训练的两种方法:DataParallel和DistributedDataParallel。DataParallel通过单进程多线程实现数据并行,但受到GIL限制效率不高。相比之下,DistributedDataParallel采用多进程方式,避免了线程效率问题,通过环形梯度同步提高并行效率。文章详细展示了DistributedDataParallel的使用步骤,并提醒在多进程环境中需要注意的print输出和模型保存等问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

用单机单卡训练模型的时代已经过去,单机多卡已经成为主流配置。如何最大化发挥多卡的作用呢?本文介绍Pytorch中的DistributedDataParallel方法。

1. DataParallel

其实Pytorch早就有数据并行的工具DataParallel,它是通过单进程多线程的方式实现数据并行的。

简单来说,DataParallel有一个参数服务器的概念,参数服务器所在线程会接受其他线程传回来的梯度与参数,整合后进行参数更新,再将更新后的参数发回给其他线程,这里有一个单对多的双向传输。因为Python语言有GIL限制,所以这种方式并不高效,比方说实际上4卡可能只有2~3倍的提速。

2. DistributedDataParallel

Pytorch目前提供了更加高效的实现,也就是DistributedDataParallel。从命名上比DataParallel多了一个分布式的概念。首先 DistributedDataParallel是能够实现多机多卡训练的,但考虑到大部分的用户并没有多机多卡的环境,本篇博文主要介绍单机多卡的用法。

从原理上来说,DistributedDataParallel采用了多进程,避免了python多线程的效率低问题。一般来说,每个GPU都运行在一个单独的进程内,每个进程会独立计算梯度。

同时DistributedDataParallel抛弃了参数服务器中一对多的传输与同步问题,而是采用了环形的梯度传递,这里引用知乎上的图例。这种环形同步使得每个GPU只需要和自己上下游的GPU进行进程间的梯度传递,避免了参数服务器一对多时可能出现的信息阻塞。

3. DistributedDataParallel示例

下面给出一个非常精简的单机多卡示例,分为六步实现单机多卡训练。

第一步,首先导入相关的包。

import argparse
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

第二步,加一个参数,local_rank。这比较好理解,相当于就是告知当前的程序跑在那一块GPU上,也就是下面的第三行代码。local_rank是通过pytorch的一个启动脚本传过来的,后面将说明这个脚本是啥。最后一句是指定通信方式,这个选nccl就行。

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)

dist.init_process_group(backend='nccl')

第三步,包装Dataloader。这里需要的是将sampler改为DistributedSampler,然后赋给DataLoader里面的sampler。

为什么需要这样做呢?因为每个GPU,或者说每个进程都会从DataLoader里面取数据,指定DistributedSampler能够让每个GPU取到不重叠的数据。

读者可能会比较好奇,在下面指定了batch_size为24,这是说每个GPU都会被分到24个数据,还是所有GPU平分这24条数据呢?答案是,每个GPU在每个iter时都会得到24条数据,如果你是4卡,一个iter中总共会处理24*4=96条数据。

train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)

trainloader = torch.utils.data.DataLoader(my_trainset,batch_size=24,num_workers=4,sampler=train_sampler)

第四步,使用DDP包装模型。device_id仍然是args.local_rank。

model = DDP(model, device_ids=[args.local_rank])

第五步,将输入数据放到指定GPU。后面的前后向传播和以前相同。

for imgs,labels in trainloader:
        
    imgs=imgs.to(args.local_rank)
    labels=labels.to(args.local_rank)
    
    optimizer.zero_grad()
    output=net(imgs)
    loss_data=loss(output,labels)
    loss_data.backward()
    optimizer.step()

第六步,启动训练。torch.distributed.launch就是启动脚本,nproc_per_node是GPU数。

python -m torch.distributed.launch --nproc_per_node 2 main.py

通过以上六步,我们就让模型跑在了单机多卡上。是不是也没有那么麻烦,但确实要比DataParallel复杂一些,考虑到加速效果,不妨试一试。

4. DistributedDataParallel注意点

DistributedDataParallel是多进程方式执行的,那么有些操作就需要小心了。如果你在代码中写了一行print,并使用4卡训练,那么你将会在控制台看到四行print。我们只希望看到一行,那该怎么做呢?

像下面一样加一个判断即可,这里的get_rank()得到的是进程的标识,所以输出操作只会在进程0中执行。

if dist.get_rank() == 0:
    print("hah")

你会经常需要dist.get_rank()的。因为有很多操作都只需要在一个进程里执行,比如保存模型,如果不加以上判断,四个进程都会写模型,可能出现写入错误;另外load预训练模型权重时,也应该加入判断,只load一次;还有像输出loss等一些场景。

【参考】 [原创][深度][PyTorch] DDP系列第一篇:入门教程 - 知乎

 

### PyTorch 分布式训练教程和最佳实践 #### 一、背景介绍 随着数据集规模的增长和模型复杂性的增加,单台机器的计算能力逐渐难以满足需求。分布式训练成为解决这一问题的有效途径之一。PyTorch作为一个灵活且强大的深度学习框架,在支持分布式训练方面表现优异[^1]。 #### 二、核心概念与联系 在深入探讨之前,有必要先理解几个关键术语: - **进程组(Process Group)**: 是指参与同一轮同步操作的一群进程; - **Rank (排名)**: 表示当前进程中在整个集群中的唯一编号; - **World Size (世界大小)**: 即整个集群内所有可用设备的数量总和; 这些概念对于构建有效的分布式系统至关重要[^3]。 #### 三、安装依赖项 要开始使用PyTorch进行分布式训练,首先需要确保环境中已正确安装必要的库文件和其他软件组件。这通常包括但不限于Python版本兼容性检查、CUDA驱动程序更新等准备工作。具体步骤可参照官方指南完成相应设置[^2]。 #### 四、初始化环境变量 成功部署好基础架构之后,则需配置一些重要的环境参数来指定运行模式: ```bash export NCCL_DEBUG=INFO export MASTER_ADDR="localhost" export MASTER_PORT=12355 ``` 上述命令分别用于开启调试日志记录功能以便于排查错误信息,并定义主节点地址及其监听端口号以供其他子节点连接建立通信链路之用。 #### 五、编写并执行简单的DDP程序 下面给出一段基于`torch.nn.parallel.DistributedDataParallel`(简称 DDP )实现的基础代码样例: ```python import torch from torch import nn, optim from torchvision import datasets, transforms def main(rank, world_size): transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=rank) dataloader = torch.utils.data.DataLoader(dataset,batch_size=64,sampler=sampler,num_workers=4,pin_memory=True) if __name__ == '__main__': # 假设我们有两块GPU卡可供分配给两个独立的工作线程 WORLD_SIZE = 2 mp.spawn(main,args=(WORLD_SIZE,),nprocs=WORLD_SIZE,join=True) ``` 此脚本展示了如何创建一个简易的数据加载器,并将其封装成适合分布式的迭代对象形式传递给下游模块处理。 #### 六、探索更复杂的特性 除了基本的功能之外,DDP还提供了一系列高级选项用来增强系统的灵活性与鲁棒性。例如可以通过调整`find_unused_parameters`标志位控制是否自动检测未使用的张量从而减少不必要的内存占用;或是借助内置API接口自定义梯度聚合策略等等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值