Pytorch 分布式数据并行训练
简要说明:本例使用torch.nn.Linear()作为网络模型,并将其用 torch.nn.parallel.DistributedDataParallel() (简称,DDP)包装,随后进行一次前向传播,一次反向传播,DDP上优化器的一次优化,最后,局部网络模型得到优化,并同步到不同的卡中。
主要方法:
torch.nn.parallel.DistributedDataParallel()
上例子:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
def example(rank, world_size):
# create default process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# create local model
model = nn.Linear(10, 10).to(rank)
# construct DDP model
ddp_model = DDP(model, device_ids=[rank])
# define loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
# forward pass
outputs = ddp_model(torch.randn(20, 10).to(rank))
labels = torch.randn(20, 10).to(rank)
# backward pass
loss_fn(outputs, labels).backward()
# update parameters
optimizer.step()
def main():
world_size = 2
mp.spawn(example,
args=(world_size,),
nprocs=world_size,
join=True)
if __name__=="__main__":
main()
下面将介绍内部实现,不关心的可以关网页了(呀~~希望点赞没事)
一次迭代过程细节
- 初始化:DDP依靠c10d的 ProcessGroup 进行通信,因此,在构造DDP之前,必须先实例化ProcessGroup。
- 构造DDP:DDP的构造函数需要引用本地模型,并从rank 0的进程广播模型参数至组内其他进程,从而保证所有模型的参数来源一致;随后DDP进程创建一个局部Reducer,用于确保反向传播过程中各设备上参数梯度一致。同时,为了保证通信效率,Reducer将模型参数梯度统一组织到Bucket中,从而一次性完成参数更新,其中Bucket容量可以在构造DDP时通过bucket_cap_mb完成参数设置。参数梯度与Bucket的映射关系是在DDP构造时依据模型参数大小和Bucket容量建立的,而模型参数大小则是由给定模型的参数,即Model.parameters(),以逆序形式分配到Bucket中,之所以是逆序,是DDP为了确保反向传播时,参数是顺序更新的。除了Bucket,Reducer在构造阶段还未每个参数注册了一个钩子(hook),一旦梯度准备好,就会触发,从而实现快速反向传播。
- 前向传播:前向传播过程中,DDP获取输入并将其传入本地模型,随后,当find_unused_parameters为True时,分析本地模型的输出。该模式下允许模型参数在子计算图上完成反向传播,DDP通过从模型的输出中,遍历计算图,标记并还原未使用的参数,从而获取反向传播过程中涉及的参数。在反向传播过程中,Reducer不仅等待未读参数,还需要还原所有的buckets。将参数梯度标记为ready并不能帮助DDP跳过当前的bucket,但可以防止反传中DDP等待不存在的梯度。注意,遍历计算图会带来额外的开销,因此只需要在必要时,将find_unused_parameters置为True。
- 反向传播:反向传播backward函数直接在超出DDP控制的损失张量上调用,DDP使用在构造时注册的自动加载钩子来触发梯度同步。当某参数的梯度已是ready状态,该参数的DDP钩子将会触发,随后,DDP将会将该参数梯度标记为ready,当一个Bucket中的梯度都准备好了,Reducer在该Bucket上启动一个异步allreduce来计算所有进程的梯度平均值。当所有Buckets准备就绪时,Reducer将阻塞等待所有allreduce操作完成。完成此操作后,平均梯度将写入所有参数的 param.grad 字段里。因此,在反向传播后,不同DDP进程中同一参数对应的梯度值应该是相同的。
- 优化过程:从优化器的角度来看,虽然它只优化了一个本地模型,但所有DDP进程上的模型副本都可以保持同步,这是由于各模型从相同的模型参数出发,并在每次迭代中具有相同的平均梯度。
完!!!!!!!