发现问题
最近在整理毕业设计demo的时候,发现demo不能跑通了。报错如下:
AttributeError: 'ReLU' object has no attribute 'threshold'
分析问题
demo在一个月前是可以跑通的,觉得很奇怪。在这期间,我仅仅是重新安装了Anaconda。联想到前几天看到师兄的票圈说pytorch自己更新了模型部分的代码,怀疑是pytorch版本的问题。目前的pytorch版本为1.0.0,以前的好像是0.4.0。
这里要写一下题外话:目前电脑安装的CUDA版本是10.0,之前CUDA版本由10.0换成8.0后,torch.cuda()会卡死,遂又换回了10.0。然而,我在pytorch官网上并没有看到对应cuda10.0的版本为0.4.0的pytorch。因此,更换pytorch版本这条路是走不通了。
解决问题
再回看报错内容,ReLU是模型中的结构。猜想是旧模型的结构在新的torch版本中水土不服了。此外,模型的加载并没有问题,报错是出现在模型的forward中的。因此可以考虑将旧模型的参数复制到新模型结构中。具体步骤为:
- 加载旧模型参数
- 创建一个新模型
- 将旧模型参数复制到新模型中
- 保存新模型
代码实现如下:
from torch import load, save
from torchreid.models.splitresnet import ResNet50Split
model = load('log/best_model.pth.tar')
new_model = ResNet50Split(num_classes=40)
model_dict = model.state_dict()
new_model_dict = new_model.state_dict()
model_dict = {k: v for k, v in model_dict.items() if k in new_model_dict.keys()}
new_model_dict.update(model_dict)
save(new_model, 'log/new_model.pth.tar')
用新的模型代替旧模型,demo成功运行!