使用 torch.nn.DataParallel 训练并保存的模型，其 state_dict 中每个 key 前都多了 `module.` 前缀；把它加载到单 GPU 或 CPU 环境中的普通（未包装的）模型时，会因为找不到对应的 key 而报错，因此需要先去掉这个前缀。
def load_model_without_module(model_state_file, prefix="module.", map_location="cpu"):
    """Load a checkpoint and strip the ``nn.DataParallel`` key prefix.

    Models wrapped in ``torch.nn.DataParallel`` save their parameters under
    keys such as ``"module.conv1.weight"``, while a plain (unwrapped) model
    expects ``"conv1.weight"``. This helper returns a state dict whose keys
    have the prefix removed so it can be passed to ``model.load_state_dict``.

    Args:
        model_state_file: Path (or file-like object) accepted by ``torch.load``.
        prefix: Key prefix to strip. Keys that do not start with it are kept
            unchanged instead of being truncated.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on GPU can be loaded on a CPU-only machine;
            ``load_state_dict`` later copies the values onto the model's
            own device.

    Returns:
        An ``OrderedDict`` mapping parameter names (prefix removed) to their
        values, in the checkpoint's original order.

    Raises:
        Propagates any exception raised by ``torch.load`` (e.g. missing file).
    """
    from collections import OrderedDict
    import torch

    checkpoint = torch.load(model_state_file, map_location=map_location)
    # Checkpoints are often saved as {"state_dict": ..., <other metadata>};
    # if there is no such wrapper, treat the loaded object as the state dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Only strip the prefix when it is actually present; blindly slicing
    # would corrupt keys saved without DataParallel.
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
def main():
    """Example: load a DataParallel-trained checkpoint into a plain model.

    This is a template: fill in the two TODO placeholders before running.
    """
    # ... (other setup elided)
    model = ...  # TODO: build the same architecture that was trained
    model_state_file = ...  # TODO: path to the saved checkpoint file
    # Leave the DataParallel wrap disabled: the keys returned by
    # load_model_without_module no longer carry the "module." prefix, so they
    # match the unwrapped model.
    # model = torch.nn.DataParallel(model).cuda()
    model.load_state_dict(load_model_without_module(model_state_file))
参考:
https://discuss.pytorch.org/t/keyerror-state-dict/18220/2