class torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False)
Implements distributed data parallelism based on the torch.distributed package at the module level.
This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension.
The module is replicated on each machine and each device, and each such replica handles a portion of the input. During the backwards pass, gradients from each node are averaged.
The batch size should be larger than the number of GPUs used locally.
See also: "Basics" and "Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel". The same constraints on input as in torch.nn.DataParallel apply.
Creation of this class requires that torch.distributed be initialized first, by calling torch.distributed.init_process_group().
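For concreteness, here is a minimal sketch of that setup in the one-process-per-GPU style; the NCCL backend, the rendezvous address, and the toy single-layer module are illustrative assumptions, not part of the original documentation:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

def setup_and_wrap(rank, world_size):
    # torch.distributed must be initialized before constructing DDP.
    dist.init_process_group(
        backend="nccl",                       # assumed: NCCL for GPU training
        init_method="tcp://127.0.0.1:23456",  # assumed rendezvous address
        rank=rank,
        world_size=world_size,
    )

    model = nn.Linear(10, 10).to(rank)        # toy module, one GPU per process
    # Wrap the module; gradients are averaged across processes during backward.
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    return ddp_model
```

Each process would construct its own wrapped replica this way (for example, one process per GPU launched via torch.multiprocessing or a launcher script) before running the training loop.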
DistributedDataParallel is proven to be significantly faster than