【主要参考资料】
一、总体概览
torch.distributed 可以分为三个主要部分:DDP、RPC和c10d。:
1.分布式数据并行训练 Distributed Data-Parallel Training(DDP):用于单一程序的数据并行训练。首先为每个进程复制一份模型,不同模型副本处理训练集中的不同数据( a different set of input data samples)。
DDP使用gradient communication以保持模型副本同步,并将其与梯度计算进行重叠(这样可以加快训练)。
2.基于RPC的分布式训练 RPC-Based Distributed Training (RPC) : 支持无法进行数据并行训练的一些general training structures,例如分布式流水线并行/参数服务器范式/DDP与其他训练范式的组合。
有时候我们希望让一部分进程中的模型完成A操作,让另一部分进程中的模型完成B操作。此时就可以使用RPC实现不同进程间模型的通信和协作,完成更复杂的任务。
3.集体通信库 Collective Commmunication (c10d) library: 支持组内进行进程通信,发送张量。它不但提供了集体通信API(如all_reduce和all_gather),也提供了P2P通信API(如send和isend)。
这个API库较为底层,前面提到的DDP,使用的是c10d的集体通信API,RPC使用的是c10d的P2P API。
当使用DDP进行分布式训练时,通常每次迭代都会同步梯度,如果我们想在一系列迭代后再进行同步,一种用法使用c10d是进行分布式参数平均,即计算所有模型的参数的平均值,而不是使用DDP来传达梯度。这样可以将通信与计算解耦,并允许对通信内容进行更精细的控制,但是另一方面,却放弃了DDP提供的性能优化。
在分布式计算中,可能不同的设备之间的通信方式不同,也可以使用c10d进行更精细的通信控制。
API使用教程见Writing Distributed Applications with PyTorch — PyTorch Tutorials 2.0.1+cu117 documentation
根据并行训练的难易程度,可以分为:
1.单一机器单一GPU训练
2.单一机器多GPU分布式训练 DataParallel
3.单一机器多GPU分布式训练 DistributedDataParallel
4.多台机器多GPU分布式训练
5.torch.distributed.elastic:允许资源在训练过程中弹性增减
如果预期训练过程中会出现错误(如内存溢出),或者在训练过程中资源会动态的加入或离开,可以使用torch.distributed.elastic来启动分布式训练。
Data-parallel training also works with Automatic Mixed Precision (AMP).
二、主要API
数据并行的思路:将同一批次的数据划分为不同子批次,在不同GPU上处理。梯度在所有GPU中同步。
首先理解DataParallel和DistributedDataParallel的区别:
方面 | DataParallel | DistributedDataParallel |
数据处理方式 | 单进程多线程。容易碰到GIL连接问题。每次前向传播前都需要复制模型到其他GPU。 | 每个GPU都有独立的Python进程。相比于DataParallel,需要多一个init_process_group操作。因为是多进程操作,所以没有GIL问题。且模型只在创建之初复制,而不是每次前向传播都要复制。 |
效率 | 可能会在主设备(通常是GPU 0)上遇到瓶颈,因为它需要收集所有GPU的输出并计算损失。 | 更加高效,因为每个进程独立运行,并在计算完成后同步梯度。 |
使用范围 | 单一机器,多GPU。 | 多机器,多GPU。 |
假设我们有一台机器,上面有4个GPU,需要处理的一个批次数据由400个样本组成。
在使用DataParallel
时,每个GPU可能处理100个样本。然后在主GPU上计算损失并同步梯度。
而使用DistributedDataParallel
时,会有4个进程,每个进程在一个GPU上运行并处理100个样本,然后梯度在四个进程之间同步。
DDP is shipped with several performance optimization technologies. For a more in-depth explanation, refer to this paper (VLDB’20).
打算本轮先学习DDP,后面学有余力学习torch.distributed.elastic(资源动态利用)和RPC(除了数据并行外还有许多可以并行的范式,RPC可以实现如参数并行、强化学习多模块并行等。)
三、torch.nn.parallel.DistributedDataParallel
学习资料
-
DDP notes offer a starter example and some brief descriptions of its design and implementation. If this is your first time using DDP, start from this document.
-
Getting Started with Distributed Data Parallel explains some common problems with DDP training, including unbalanced workload, checkpointing, and multi-device models. Note that, DDP can be easily combined with single-machine multi-device model parallelism which is described in the Single-Machine Model Parallel Best Practices tutorial.
-
The Launching and configuring distributed data parallel applications document shows how to use the DDP launching script.
-
The Shard Optimizer States With ZeroRedundancyOptimizer recipe demonstrates how ZeroRedundancyOptimizer helps to reduce optimizer memory footprint.
-
The Distributed Training with Uneven Inputs Using the Join Context Manager tutorial walks through using the generic join context for distributed training with uneven inputs.
四、DDP notes
以下是一个简单的使用 torch.nn.parallel.DistributedDataParallel 的例子。使用torch.nn.Linear作为本地模型,并封装为DDP模型,在DDP模型上分别进行一次正向传播、方向传播、优化器优化,从而更新本地模型的参数,同时其他进程上的参数保持一致。
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
def example(rank, world_size):
# create default process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# create local model
model = nn.Linear(10, 10).to(rank)
# construct DDP model
ddp_model = DDP(model, device_ids=[rank])
# define loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
# forward pass
outputs = ddp_model(torch.randn(20, 10).to(rank))
labels = torch.randn(20, 10).to(rank)
# backward pass
loss_fn(outputs, labels).backward()
# update parameters
optimizer.step()
def main():
world_size = 2
mp.spawn(example,
args=(world_size,),
nprocs=world_size,
join=True)
if __name__=="__main__":
# Environment variables which need to be
# set when using c10d's default "env"
# initialization mode.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
main()