个人使用pytorch
的时候需要用到多GPU
运行,简要说明一下应用情景:
- 单
GPU
不够用,你需要将模型存储在多个GPU
上; - 当模型初始化后运行在多个GPU上,你要加载dict模型参数文件。
第一个情景,我们使用 nn.DataParallel
来解决,直接上例子:
# 使用nvidia-smi查看可用的设备
CUDA_DEVICE_1 = 0
CUDA_DEVICE_2 = 1
# 模型初始化
net = model()
# 模型进入DataParallel,device_ids指明设备号列表
net = nn.DataParallel(net