如何保存和读取pytorch模型

如何保存和读取pytorch模型


相信大家也会遇到这样的问题吧,在使用pytorch训练自己模型的时候,如果不将我们训练的模型保存起来,我们每一次都是从头开始训练我们的模型,这样真的很麻烦。其实在我的上一篇博客中我已经发现这个问题了。

1.保存模型

#定义保存模型函数
def save_model(the_model,PATH):
    torch.save(the_model.state_dict(),PAT

当我们的模型训练完毕之后,我们只需调用一下该函数就可以了

save_model(cnn,'cnn.pth')
#这里的cnn就是我要保存的训练好的模型,cnn.pth就是要保存为的名称,
#一般来说pytorch的模型后缀都是.pth

2.读取模型

例如我们想要在另外的一个python文件中读取我们之前已经保存好的模型,我们需要先创建一个和之前模型一样的空模型来接收。

import torch
from cnn_test import CNN

best_model=CNN()
#定义一个与之前模型一致结构的模型来接收
best_model.load_state_dict(torch.load('cnn.pth'))
#加载之前的模型,这里的‘cnn.pth’就是我上一步保存的模型文件
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王延凯的博客

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值