1.遇到的问题
训练好的模型,保存了参数,加载模型参数,报错:
2.错误分析
训练代码中,模型声明的时候,放到gpu上面,使用了nn.DataParallel,表示多块gpu分布式训练。保存模型的时候,使用我现在的保存方式,在加载模型的key值中会多一个module.这七个字符。
model = fusion()
# model = model.cuda()
model=nn.DataParallel(model.to(device),device_ids=gpus,output_device=gpus[0])
train保存模型:
torch.save(model.state_dict(), f'{args.save_path}/fusion_model_epoch_{epoch}.pth')
test中加载模型:
# model.load_state_dict(torch.load(args.fusion_pretrained)) #,strict=False
3.解决方法
方法一:(成功)
1.去掉key值的前七个字符
# original saved file with DataParallel
state_dict = torch.load(model_path)
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
方法二:
最好使用torch.save(model.module.state_dict(), model_out_path)保存模型,这样等到需要测试网络时,加载模型时用model.load_state_dict(torch.load(PATH, map_location=device))直接加载模型。
原文链接
4.拓展
有的时候是因为改动了模型,加载原来的训练好的模型参数,导致参数不匹配。简单易懂的解释如下: model.load_state_dict(state_dict, strict=False)