CLASS torch.nn.
DataParallel
(module, device_ids=None, output_device=None, dim=0)
在模块水平实现数据并行。
该容器通过在批处理维度中分组,将输入分割到指定的设备上,从而并行化给定模块的应用程序(其它对象将在每个设备上复制一次)。在前向传播时,模块被复制到每个设备上,每个副本处理输入的一部分。在反向传播时,来自每个副本的梯度被累加到原始模块中。
批处理大小应该大于所使用的GPU数量。
警告:
在使用多GPU训练时,推荐使用 DistributedDataParallel,而不是使用这个类,即使你仅仅使用一个节点。参考 Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel 和Distributed Data Parallel
允许将任意位置和关键字输入传递给DataParallel,但是有些类型是专门处理的。张量将被分散在指定的dim上(默认是0)。tuple,list,dict将被浅复制。其它类型将在不同的线程中共享,如果在模型前向传递中写入,则可能被破坏。
The parallelized module
must have its parameters and buffers on device_ids[0]
before running this