CLASS torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
- 在模块级实现数据并行。
- 该容器通过在批处理维度中分组,将输入分割到指定的设备上,从而并行化给定模块的应用程序(其他对象将在每个设备上复制一次)。在前向传播中,模块被复制到每个设备上,每个副本处理输入的一部分。在反向传播过程中,来自每个副本的梯度被累加到原始模块中。
- 批处理大小应该大于所使用的gpu数量。
- 允许将任意位置和关键字输入传递到
DataParallel
中,但有些类型是专门处理的 tensor
将分散到指定dim
(默认为0)。tuple
,list
以及dict
类型将浅拷贝。其他类型将在不同的线程之间共享,如果在模型的前向传播中写入,则可能被破坏。- 并行模块必须在
device_ids[0]
上有它的parameters
和buffers
,然后才能运行这个DataParallel
模块。
Parameters
module (Module) – module to be parallelized
device_ids (list of python:int or torch.device) – CUDA devices (default: all devices)
output_device (int or torch.device) – device location of output (default: device_ids[0])
Example
net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
output = net(input_var) # input_var can be on any device, including CPU