Pytorch中gpu的并行运算
常用的最多的就是,多块GPU训练同一个网络模型。Pytorch中的并行运算。
1. 多GPU输入数据并行运算
一般使用torch.nn.DataParallel
,例如:
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)
2. 推荐GPU设置方式:
- 单卡
使用CUDA_VISIBLE_DEVICES
指定GPU,然后.cuda()
不传入参数import os os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids # gpu_ids参数为int类型,如 0 model.cuda()
- 多GPU输入数据并行处理
import os gpu_list = '0,1,2,3' os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list device_ids = [0, 1, 2, 3] net = torch.nn.DataParallel(net, device_ids=device_ids)
3. 保存加载多GPU网络
net = torch.nn.Linear(10,1) # 构造网络
net = torch.nn.DataParallel(net, device_ids=[0,1, 2, 3])
torch.save(net.module.state_dict(), './model/multiGPU.pth') #保存网络
# 加载网络
new_net = torch.nn.Linear(10,1)
new_net.load_state_dict(torch.load("./model/multiGPU.pth"))