pytorch存储/加载模型及多GPUs条件下的注意事项

一、存储/加载模型

可以选择保存整体model网络结构和参数

PATH = 'saved_model.pth'
# 保存整个model:
torch.save(model_0, PATH)
# 加载整个model:
model_1 = torch.load(PATH)

or只保存model参数

PATH = 'saved_model.pth'
# 由于只加载参数,因此需要提前定义网络结构,例如:
class Net(nn.Module):
...
# 只保存参数:
torch.save(model_0.state_dict, PATH)
# 分别加载网络和参数:
model_1 = Net()
model_1.load_state_dict(torch.load(PATH), strict=False)  # 只会加载键值相同的参数

二、多GPUs下加载过程出错

使用 torch.nn.DataParallel(model_0, device_ids=[0, 1])语句后,加载的模型变量会多“module”关键字,例如:

model_0 = torch.load('resnet50.pth')
model_1 = torch.nn.DataParallel(model_0, device_ids=[0, 1])
torch.save(model_1, PATH)
model_2 = torch.load(PATH)
model_3 = model_2.module

调试界面显示:
在这里插入图片描述
如果直接加载model_2,则可能会因为多了module关键字而报错:

IndexError: list index out of range

其实就是网络结构和参数对不上了(因为多了module)
可以使用model_3 = model_2.module语句,去掉module。

注意:
(1) 每次使用 torch.nn.DataParallel(model_0, device_ids=[0, 1])都会生成一个module关键字,所以用了几次就要去掉几次;

(2) 哪怕 torch.nn.DataParallel(model_0, device_ids=[0,])也会有module的,如果只想用1个GPU,直接用tensor.cuda()

参考

https://blog.csdn.net/CV_YOU/article/details/86670188
https://blog.csdn.net/qq_37959202/article/details/105104278

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值