Before you begin: https://pytorch.org/tutorials/beginner/dist_overview.html
CLASS torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False)
Implements distributed data parallelism at the module level, based on the torch.distributed package.
This container parallelizes the application of the given module by splitting the input across the specified devices, chunking along the batch dimension. The module is replicated on each machine and each device, and each such replica handles a portion of the input. During the backward pass, gradients from each node are averaged.
The batch size should be larger than the number of GPUs used locally.
See also: Basics, and Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel. The same constraints on input as in torch.nn.DataParallel apply to torch.nn.parallel.DistributedDataParallel.
Creation of this class requires that torch.distributed be already initialized, by calling torch.distributed.init_process_group().
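As a minimal sketch (assuming a single-node setup, the NCCL backend, and the env:// init method with MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE already set in the environment; the model is a placeholder), the process group is initialized before the module is wrapped:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    # The default process group must exist before DistributedDataParallel is constructed.
    dist.init_process_group(backend="nccl", init_method="env://")

    # Pin this process to one local GPU (illustrative single-node rank-to-GPU mapping).
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)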
DistributedDataParallel is proven to be significantly faster than torch.nn.DataParallel for single-node multi-GPU data parallel training.
To use DistributedDataParallel on a host with N GPUs, you should spawn up N processes, ensuring that each process exclusively works on a single GPU from 0 to N-1.
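The sketch below is one way to do this, not the only one; it assumes torch.multiprocessing.spawn, the NCCL backend, a hard-coded localhost rendezvous address and port, and a placeholder model. It spawns one process per local GPU and pins each process to its own device:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.nn.parallel import DistributedDataParallel as DDP

    def worker(rank, world_size):
        # One process per GPU: process `rank` uses GPU `rank` exclusively.
        os.environ["MASTER_ADDR"] = "127.0.0.1"   # illustrative rendezvous address
        os.environ["MASTER_PORT"] = "29500"       # illustrative free port
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

        model = torch.nn.Linear(10, 10).to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        # ... forward pass, loss.backward() (gradients are averaged across processes),
        #     optimizer.step() ...

        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = torch.cuda.device_count()  # N GPUs -> N processes
        mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)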