了解torch.nn.DataParallel

CLASS torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

  • 在模块级实现数据并行。
  • 该容器通过在批处理维度中分组,将输入分割到指定的设备上,从而并行化给定模块的应用程序(其他对象将在每个设备上复制一次)。在前向传播中,模块被复制到每个设备上,每个副本处理输入的一部分。在反向传播过程中,来自每个副本的梯度被累加到原始模块中。
  • 批处理大小应该大于所使用的gpu数量。
  • 允许将任意位置和关键字输入传递到DataParallel中,但有些类型是专门处理的
  • tensor将分散到指定dim(默认为0)。tuplelist以及dict类型将浅拷贝。其他类型将在不同的线程之间共享,如果在模型的前向传播中写入,则可能被破坏。
  • 并行模块必须在device_ids[0]上有它的parametersbuffers,然后才能运行这个DataParallel模块。

Parameters

  • module (Module) – module to be parallelized
  • device_ids (list of python:int or torch.device) – CUDA devices (default: all devices)
  • output_device (int or torch.device) – device location of output (default: device_ids[0])

Example

net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
output = net(input_var)  # input_var can be on any device, including CPU

pytorch文档

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值