项目场景:
pytorch加载保存模型问题描述:
加载保存模型解决方案:
非常感谢这个链接
不错的记录链接:https://blog.csdn.net/comli_cn/article/details/107516740
配合该链接食用效果更加
我这里的train替换了他的predict
一、保存整个模型
虽然占用内存大,但我觉得比仅仅保存模型参数省事
保存路径:
PATH_all = '/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model_all/model_AD_NC_1.pt'
1.保存模型:
先正常运行模型:
model = Classifier().cuda()
torch.save(model, PATH_all)
2.加载模型:
注释掉 torch.save(model, PATH_all)
不要注释 model = Classifier().cuda()
然后在 你要 train model 前加上
new_m = torch.load('/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model_all/model_AD_NC_1.pt')
然后把train命令中的model名字改成你加载的模型名字
train( new_m, train_loader, val_loader, test_loader)
二、仅仅保存模型参数
保存路径:
save_path = '/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model/AD_NC_2.pt'
1.保存模型:
先正常运行模型:
model = Classifier().cuda()
torch.save(model.state_dict(), save_path)
2.加载模型:
注意,在调用模型参数时(因为调用的不是模型),需要先实例化
m_state_dict = torch.load('/home/ubuntu/liyafeng/NEW_train/728_fusion/daima/MRI_daima/AD_NC/save_model/AD_NC.pt')
new_m = Classifier().cuda()
new_m.load_state_dict(m_state_dict) #实例化
train( new_m, train_loader, val_loader, test_loader)