from torch.nn import DataParaller
class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
Parameters:
module – module to be parallelized
device_ids – CUDA devices (default: all devices)
output_device – device location of output (default: device_ids[0])
Variables:
module (Module) – the module to be parallelized
可以实现单机多个GPU的并行训练,在forward阶段,将输入batch分平分为几块数据分别送往指定的GPU进行运算。在backward阶段收集各个GPU反向传播的梯度并求和,作用在原模型上,实现加速训练。
例子:
net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
output = net(input_var)
据说新函数
torch.nn.parallel.distributedDataParallel()表现的更好,还支持分布式训练(多机)