RuntimeError: Error(s) in loading state_dict for
背景:
使用别人训练好的训练权重进行检测,发现出现上述问题
问题描述:
解决办法:
1.将 load_state_dict 中 strict 参数设置为 False
修改前:
model.load_state_dict(torch.load(model_path))
修改后:
model.load_state_dict(torch.load(model_path), False)
但我本人在使用上述方法后出现了预测结果表现不佳,训练权重似乎并没有使用上
出现这种情况可以用第二种方法
2.未将模型正确导入
博主是新人,在CSDN上搜了许多结果,但是仍然是没有解决。
最后无意间发现了这个超级低级的错误!
在直接使用预先训练好的权重的时候,没有载入相关的模型。
例如博主是使用了Densenet进行一个新的分类训练,但我仍然使用的是之前的模型
from models.SwinT import swin_base_patch4_window7_224 as create_model
于是我找到现在使用的模型将其导入
from models.Densenet import densenet121 as create_model1
问题自然就解决了
至于模型如何导入,博主是网上下载的模型,在你训练的文件中一般有一个model文件
找到你具体使用的权重,例如我使用的是densenet121
然后使用
from models.Densenet import densenet121 as create_model1
其中 models是我项目中的一个文件夹,后面的Densenet是我改名后的model文件,create_model1是我新定义的一个模型