pytorch保存和加载模型

之前想学习保存和加载模型的代码,在知乎上看到一个回答,发现两行代码就可以搞定,于是兴冲冲的加上了:

torch.save(model, "model.pth.tar") 
model_dict=torch.load("model.pth.tar")

然后就大胆的去训练了,结果训练结束,准备load时,发现load得到的结果,就只有模型的结构,参数完全没保存下来…
(哎,当时看到答主说这种方式是保存了整个网络,就以为整个网络必然包括参数啊,谁知道仅仅是结构)

于是换了一种方式:

checkpoint = {
    "model_struct": model,
    "model_param": model.state_dict(),
    "model_cfg": config}
torch.save(checkpoint, “model.ckpt")

这种方式是自己建立一个字典checkpoint,然后分别保存模型结构model_struct、模型参数model_param和相关配置model_cfg,然后保存为.ckpt文件(至于为何.pth.tar.ckpt到底有什么不一样,暂时还不清楚)

这次的教训就是,单保存模型是不能保存网络参数的,需要调用模型的.state_dict()属性将参数拿出来

参考文章:
https://zhuanlan.zhihu.com/p/38056115

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值