pytorch 分布式训练

本文介绍了分布式训练中的数据并行,其可将大量数据分配到不同机器/卡上,加快训练速度。还阐述了All Reduce和Server - Work两种架构,前者通过单节点上reduce + broadcast操作完成,Pytorch用Ring AllReduce;后者中server负责数据分发等,worker负责载入数据和前向计算。此外,介绍了Pytorch支持的分布式方式。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

分布式训练介绍

数据并行当任务数据的量非常大,但是模型本身规模不大(显存可以放得下)的时候,可以把数据按照一定的规则分配到不同的机器/卡上分工合作。同时,每个机器/卡都加载同一个模型,各自单独地根据输入数据计算梯度,并对梯度按照一定规则进行通信,这即梯度的汇总更新。接下来就是对模型参数的更新,确保每个iter结束的时候,所有卡上的参数一致。当然,这也要求初始化的时候每张卡的参数一致。数据并行的目的,就是希望在同样的时间内:更多的数据同时训练,加快了同等量级数据的训练完成时间, 相当于增大 batch_size,梯度下降方向会更贴合数据集整体的实际情况。

数据并行流程图如下:

All Reduce 架构:在这里插入图片描述
每个GPU计算完结果后,需要汇总更新,就是图中红色框的All-Reduce就是达成数据并行的重要操作。

Allreduce操作是常用通信模式的一种。在说Allreduce之前,先看下reduce。reduce称为规约运算,是一系列运算操作的统称,细分来说包括SUM、MIN、MAX、PROD、LOR等。reduce意为减少/精简,因为其操作在每个进程上获取一个输入元素数组,通过执行操作后,将得到精简的更少的元素。例如下面的Reduce sum:

在这里插入图片描述
可理解为多个GPU的结果汇总到一块,再通过某种方式去得到一个最终的梯度回传的结果(多GPU共享一个结果)

在这里插入图片描述
从图中可以看出,all reduce操作可通过单节点上reduce+broadcast操作完成

Allreduce方式也分多种,目前Parrots和Pytorch框架里用的都是Ring AllReduce 这种方式,而在进程间传输的就是模型参数的梯度。

在训练过程的表现可以如下图所示。
在这里插入图片描述
这个闭环的Ring可以 这张图表示。多卡训练的每张卡对应一个worker,所有worker形成一个闭环,接受上家的梯度,再把累加好的梯度传给下家。最终计算完成后更新整个环上worker的梯度(这样所有worker上的梯度就相等了),然后每张卡各自求梯度反向传播。
在这里插入图片描述
第二种架构: Server-Work架构
这里的server和worker就对应了我们训练用到的多卡,通常DP接口中输入参数的第一张卡被用作「server」,其他卡号就是对应的workers
在这里插入图片描述
Server 负责:

  • 将数据切成等量大小
  • 从 device0分发到其他卡上并把模型复制到各个卡上
  • 前向计算损失和反向计算梯度
  • 更新参数
  • 在下一个iter开始的时候把参数更新到worker上

worker负责:

  • 载入server分发过来的数据和模型
  • 模型前向计算

PS架构在训练中的劣势显而易见:通信的成本随着卡数的增加而增加; server负担过重,容易显存爆炸。
还有的设计是只把server用作梯度更新,进一步减轻负担,比如下面这个过程:
在这里插入图片描述
在Pytorch支持的分布式方式中:
DP使用的是Parameter Server(PS)架构;
DDP官网推荐,性能优于DP,
All-reduce架构;apex是英伟达出品的工具库,封装了DistributedDataParallel,All-Reduce架构。

作者:九点澡堂子
链接:https://zhuanlan.zhihu.com/p/366253646
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

### PyTorch 分布式训练教程和最佳实践 #### 一、背景介绍 随着数据集规模的增长和模型复杂性的增加,单台机器的计算能力逐渐难以满足需求。分布式训练成为解决这一问题的有效途径之一。PyTorch作为一个灵活且强大的深度学习框架,在支持分布式训练方面表现优异[^1]。 #### 二、核心概念与联系 在深入探讨之前,有必要先理解几个关键术语: - **进程组(Process Group)**: 是指参与同一轮同步操作的一群进程; - **Rank (排名)**: 表示当前进程中在整个集群中的唯一编号; - **World Size (世界大小)**: 即整个集群内所有可用设备的数量总和; 这些概念对于构建有效的分布式系统至关重要[^3]。 #### 三、安装依赖项 要开始使用PyTorch进行分布式训练,首先需要确保环境中已正确安装必要的库文件和其他软件组件。这通常包括但不限于Python版本兼容性检查、CUDA驱动程序更新等准备工作。具体步骤可参照官方指南完成相应设置[^2]。 #### 四、初始化环境变量 成功部署好基础架构之后,则需配置一些重要的环境参数来指定运行模式: ```bash export NCCL_DEBUG=INFO export MASTER_ADDR="localhost" export MASTER_PORT=12355 ``` 上述命令分别用于开启调试日志记录功能以便于排查错误信息,并定义主节点地址及其监听端口号以供其他子节点连接建立通信链路之用。 #### 五、编写并执行简单的DDP程序 下面给出一段基于`torch.nn.parallel.DistributedDataParallel`(简称 DDP )实现的基础代码样例: ```python import torch from torch import nn, optim from torchvision import datasets, transforms def main(rank, world_size): transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=rank) dataloader = torch.utils.data.DataLoader(dataset,batch_size=64,sampler=sampler,num_workers=4,pin_memory=True) if __name__ == '__main__': # 假设我们有两块GPU卡可供分配给两个独立的工作线程 WORLD_SIZE = 2 mp.spawn(main,args=(WORLD_SIZE,),nprocs=WORLD_SIZE,join=True) ``` 此脚本展示了如何创建一个简易的数据加载器,并将其封装成适合分布式的迭代对象形式传递给下游模块处理。 #### 六、探索更复杂的特性 除了基本的功能之外,DDP还提供了一系列高级选项用来增强系统的灵活性与鲁棒性。例如可以通过调整`find_unused_parameters`标志位控制是否自动检测未使用的张量从而减少不必要的内存占用;或是借助内置API接口自定义梯度聚合策略等等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值