1. 问题
其实标题不大对,因为模型参数发生了改变,只是不知道。
1.1 首先其他人给出的答案在于BN层的问题,可以通过设置model.eval()解决,但这不是这个的问题。
1.2 这个的问题出在了模型的参数上,因为使用了多GPU训练的torch.nn.DataParallel(model),导致在保存训练好的模型参数时,附带上了module,如下:
Missing key(s) in state_dict: "base.conv1.weight", "base.bn1.weight", "base.bn1.bias", "base.bn1.running_mean", "base.bn1.running_var", "base.layer1.0.conv1.weight", "base.layer1.0.bn1.weight", "base.layer1.0.bn1.bias", "base.layer1.0.bn1.running_mean",
......
Unexpected key(s) in state_dict: "module.base.conv1.weight", "module.base.bn1.weight", "module.base.bn1.bias", "module.base.bn1.running_mean", "module.base.bn1.running_var", "module.base.bn1.num_batches_tracked", "module.base.layer1.0.conv1.weight",
......
如此使得在测试模型时,导入参数的时候会有问题。
1.3 有时候问题不报错,是因为如下代码:
model.load_state_dict(state_dict, strict=False)
其中的参数strict=False就是保证即便导入的参数和模型不匹配,也会导入可以导入的参数。由此导致当你认为模型的输入和模型的参数不变,而模型的输出却在每次运行的结果都不一样。故可以设置为True或者不设置,就会严格执行对比,并且报错。
2. 解决方法
使用如下代码,改变导入参数的名称,去掉module:
ckpt = torch.load(model_path, map_location=map_location)['state_dicts'][0]
# 去除因DataParallel引起的参数module问题
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in ckpt.items():
name = k[7:]
new_state_dict[name] = v
model.load_state_dict(new_state_dict)