torch.load()
报错 Missing key(s) pytorch
错误情况: 在加载预训练模型时出错
RuntimeError: Error(s) in loading state_dict for :
Missing key(s) in state_dict: “features.0.weight” …
Unexpected key(s) in state_dict: “module.features.0.weight” …
错误原因:
使用nn.DataParallel包装后的模型参数的关键字会比没用 nn.DataParallel 包装的模型参数的关键字前面多一个"module."
解决方法:
-
使用 net 加载 nn.DataParallel(net) 训练出来的模型:
-
把 module. 删掉
# original saved file with DataParallel state_dict = torch.load('model_path') # create new OrderedDict that does not contain `module.` from collections import OrderedDict new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k[7:] # remove `module.` new_state_dict[name] = v # load params net.load_state_dict(new_state_dict)
checkpoint = torch.load('model_path') for key in list(checkpoint.keys()): if 'model.' in key: checkpoint[key.replace('model.', '')] = checkpoint[key] del checkpoint[key] net.load_state_dict(checkpoint)
-
加载模型时使用 nn.DataParallel
checkpoint = torch.load('model_path') net = torch.nn.DataParallel(net) net.load_state_dict(checkpoint)
-
-
使用 nn.DataParallel(net) 加载 net 训练出的模型:
-
保存权重前增加 module
使用 torch.save() 保存权重时,通过 model.module.state_dict() 获取模型权重
torch.save(net.module.state_dict(), 'model_path')
-
在使用nn.DataParallel之前就先读取模型,然后再使用nn.DataParallel
net.load_state_dict(torch.load('model_path')) net = nn.DataParallel(net, device_ids=[0, 1])
-
手动添加 module.
net = nn.DataParallel(net) from collections import OrderedDict new_state_dict = OrderedDict() state_dict =savepath #预训练模型路径 for k, v in state_dict.items(): # 手动添加“module.” if 'module' not in k: k = 'module.'+k else: # 调换module和features的位置 k = k.replace('features.module.', 'module.features.') new_state_dict[k]=v net.load_state_dict(new_state_dict)
-