pytorch模型导入问题
1、RuntimeError: Error(s) in loading state_dict for DataParallel:
这里说明:训练模型的测试加载模型使用的环境不一样
解决方法:
1、在load_state()函数中加上False
model.load_state(checkpoint,False)
从属性state_dic里复制到这个模块和他的后代,如果strict为True,state_dic的keys必须完全与这个模块的方法返回的keys相同,如果为False则不需要匹配。
这个方法虽然可以将模型导入,但是不能保证导入的模型参数能完全匹配模型,在做测试的时候发现导入的参数只是随机的一部分,只能识别单一分类。
2、修改模型参数的keys使其与这个模型返回的keys一致
from collections import OrderedDict
checkpoint=torch.load('参数路径')
state_dict = checkpoint
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if 'module' not in k:
k = 'module.' + k
else:
continue
new_state_dict[k]=v
model.load_state_dict(new_state_dict)
这个方法可以保证导入的模型参数跟模型完全一致。