pytorch加载保存模型

项目场景:

pytorch加载保存模型

问题描述:

加载保存模型

解决方案:

非常感谢这个链接
不错的记录链接:https://blog.csdn.net/comli_cn/article/details/107516740
配合该链接食用效果更加

我这里的train替换了他的predict

一、保存整个模型

虽然占用内存大,但我觉得比仅仅保存模型参数省事

保存路径:

PATH_all = '/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model_all/model_AD_NC_1.pt'

1.保存模型:

先正常运行模型:

model = Classifier().cuda()
torch.save(model, PATH_all)

2.加载模型:

注释掉 torch.save(model, PATH_all)
不要注释 model = Classifier().cuda()
然后在 你要 train model 前加上

new_m = torch.load('/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model_all/model_AD_NC_1.pt')

然后把train命令中的model名字改成你加载的模型名字

train( new_m, train_loader, val_loader, test_loader)

二、仅仅保存模型参数

保存路径:

save_path = '/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model/AD_NC_2.pt'

1.保存模型:

先正常运行模型:

model = Classifier().cuda()
torch.save(model.state_dict(), save_path)

2.加载模型:

注意,在调用模型参数时(因为调用的不是模型),需要先实例化

m_state_dict = torch.load('/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model/AD_NC.pt')
new_m = Classifier().cuda()
new_m.load_state_dict(m_state_dict)    #实例化
train( new_m, train_loader, val_loader, test_loader)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值