pytorch多GPU计算

pytorch多GPU计算

如果正确安装了NVIDIA驱动,我们可以通过在命令行输入nvidia-smi命令来查看当前计算机上的全部GPU

定义一个模型:

import torch
net = torch.nn.Linear(10, 1).cuda()
net

output:

Linear(in_features=10, out_features=1, bias=True)

要想使用PyTorch进行多GPU计算,最简单的方法是直接用torch.nn.DataParallel将模型wrap一下即可:

net = torch.nn.DataParallel(net)
net

output:

DataParallel(
  (module): Linear(in_features=10, out_features=1, bias=True)
)

这时,默认所有存在的GPU都会被使用。

指定使用的GPU可以使用以下方式:

torch.nn.DataParallel(net, device_ids=[0, 1])

这表示只使用0、1号显卡

多GPU模型的保存与加载

torch.save(net.state_dict(), "./test_model.pt")

加载模型前我们一般要先进行一下模型定义,此时的new_net并没有使用多GPU:

new_net = torch.nn.Linear(10, 1)
new_net.load_state_dict(torch.load("./test_model.pt"))

报错

RuntimeError: Error(s) in loading state_dict for Linear:
    Missing key(s) in state_dict: "weight", "bias". 
    Unexpected key(s) in state_dict: "module.weight", "module.bias". 

事实上DataParallel也是一个nn.Module,只是这个类其中有一个module就是传入的实际模型。因此当我们调用DataParallel后,模型结构变了。所以直接加载肯定会报错的,因为模型结构对不上。

所以正确的方法是保存的时候只保存net.module:

torch.save(net.module.state_dict(), "./test_model.pt")
new_net.load_state_dict(torch.load("./test_model.pt")) # 加载成功

或者先将new_net用DataParallel包括以下再用上面报错的方法进行模型加载:

torch.save(net.state_dict(), "./test_model.pt")
new_net = torch.nn.Linear(10, 1)
new_net = torch.nn.DataParallel(new_net)
new_net.load_state_dict(torch.load("./test_model.pt")) # 加载成功

推荐用第一种方法,因为可以按照普通的加载方法进行正确加载

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值