PyTorch模型保存torch.save()与加载torch.load()


🤵 AuthorHorizon John

编程技巧篇各种操作小结

🎇 机器视觉篇会变魔术 OpenCV

💥 深度学习篇简单入门 PyTorch

🏆 神经网络篇经典网络模型

💻 算法篇再忙也别忘了 LeetCode


模型保存

torch.save()

1)保存全部

保存整个模型

torch.save(model, path)

2)保存部分

只保存模型训练的参数(不包括网络结构)

torch.save(model.state_dict(), path)

模型加载

torch.load()

1)加载全部

加载整个模型

model = torch.load(path)

2)加载部分

只加载模型训练的参数(不包括网络结构)

model = Net()    # 网络
model = model.to(device)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint)

加载模型GPU和CPU转换

使用 map_location 参数对 GPUCPU 进行转化

1)GPU 转 CPU

model = torch.load(PATH, map_location='cpu')

2)CPU 转 GPU

model = torch.load(PATH, map_location=lambda storage, loc: storage.cuda(0))

3)GPU之间转换

torch.load(PATH, map_location={'cuda:0':'cuda:1'})    # GPU0转到GPU1

  • 3
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Horizon John

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值