Taking a quick break to jot down this bug before I forget it.
A few days ago I had happily finished training a decent model, and was asked to wrap it in a test API. After grinding away at that for a whole day, I found that the same model, given the same image, kept producing different results. I was completely baffled...
After much back and forth, I found the cause. The original code for loading my trained model looked like this:
def load_model(self, model_path, return_list=None):
    """Load the pre-trained model weight

    :return:
    """
    print(f'Loading model:{model_path}')
    checkpoint = torch.load(model_path)
    self.model.load_state_dict(checkpoint['model'], strict=False)
The code looks perfectly normal, but there is one problem, in this line:

    self.model.load_state_dict(checkpoint['model'], strict=False)

namely strict=False. It is best not to set this parameter to False: with strict=True, any mismatch between the checkpoint and the model raises an error, but with strict=False, parameters that fail to match the model are silently skipped without any warning. As soon as I changed it to strict=True, the problem showed itself immediately:
RuntimeError: Error(s) in loading state_dict for VGG16:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.features.5.weight", "module.features.5.bias", "module.features.7.weight", "module.features.7.bias", "module.features.10.weight", "module.features.10.bias", "module.features.12.weight", "module.features.12.bias", "module.features.14.weight", "module.features.14.bias", "module.features.17.weight", "module.features.17.bias", "module.features.19.weight", "module.features.19.bias", "module.features.21.weight", "module.features.21.bias", "module.features.24.weight", "module.features.24.bias", "module.features.26.weight", "module.features.26.bias", "module.features.28.weight", "module.features.28.bias", "module.fc.weight", "module.fc.bias".
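To see concretely why strict=False hides this, here is a minimal sketch with a toy nn.Linear standing in for the VGG16 above. Note that besides silently skipping mismatched keys, load_state_dict with strict=False also returns the lists of missing and unexpected keys, which you can inspect:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)

# Simulate a checkpoint saved from a DataParallel-wrapped model:
# every key carries a 'module.' prefix
wrapped_state = {'module.' + k: v for k, v in model.state_dict().items()}

# strict=True refuses to load and reports the mismatch
try:
    model.load_state_dict(wrapped_state, strict=True)
except RuntimeError as err:
    print(type(err).__name__)  # RuntimeError

# strict=False loads "successfully" -- nothing actually matches, so the model
# silently keeps its randomly initialized weights, which is why the same
# image gave different results on every run
result = model.load_state_dict(wrapped_state, strict=False)
print(result.missing_keys)     # ['weight', 'bias']
print(result.unexpected_keys)  # ['module.weight', 'module.bias']
```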
With that, the fix was clear. The updated code:
def load_model(self, model_path):
    """Load the pre-trained model weight

    :param model_path:
    :return:
    """
    checkpoint = torch.load(model_path, map_location=self.device_name)['model']
    # TODO: dig into exactly why the prefix appears (see the summary below)
    checkpoint_parameter_name = list(checkpoint.keys())[0]
    model_parameter_name = next(self.model.named_parameters())[0]
    is_checkpoint = checkpoint_parameter_name.startswith('module.')
    is_model = model_parameter_name.startswith('module.')
    if is_checkpoint and not is_model:
        # Strip the 'module.' prefix from the checkpoint keys
        new_parameter_check = OrderedDict()
        for key, value in checkpoint.items():
            if key.startswith('module.'):
                new_parameter_check[key[7:]] = value
        self.model.load_state_dict(new_parameter_check)
    elif not is_checkpoint and is_model:
        # Add the 'module.' prefix to the checkpoint keys
        new_parameter_dict = OrderedDict()
        for key, value in checkpoint.items():
            if not key.startswith('module.'):
                key = 'module.' + key
            new_parameter_dict[key] = value
        self.model.load_state_dict(new_parameter_dict)
    else:
        self.model.load_state_dict(checkpoint)
    return self.model
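The key-renaming branch above does not depend on PyTorch at all, so the logic can be checked in isolation. A minimal sketch (strip_module_prefix is a hypothetical helper name, not part of the original class):

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Drop a leading 'module.' from every key, if present."""
    new_state = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith('module.'):
            key = key[len('module.'):]  # len('module.') == 7
        new_state[key] = value
    return new_state

# Dummy values stand in for real weight tensors
saved = OrderedDict([('module.features.0.weight', 0), ('module.fc.bias', 1)])
print(list(strip_module_prefix(saved).keys()))
# ['features.0.weight', 'fc.bias']
```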
Summary:

The root cause is that during training I wrapped the model with self.model = torch.nn.DataParallel(self.model). Testing confirms that when computing on the GPU, a model that has gone through this wrapper exposes its parameters with a 'module.' prefix, so the weights saved in that state naturally carry the 'module.' prefix as well. This is worth keeping in mind.
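Given that cause, an easy way to sidestep the whole problem is to save the state_dict of the inner model (the .module attribute of the DataParallel wrapper) rather than the wrapper itself. A minimal sketch:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

# Keys read through the wrapper carry the 'module.' prefix
print(list(wrapped.state_dict().keys()))         # ['module.weight', 'module.bias']

# Saving the inner module keeps the keys clean, so the checkpoint later
# loads into a plain (non-DataParallel) model without any renaming
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']
# torch.save({'model': wrapped.module.state_dict()}, 'checkpoint.pth')
```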