一、问题
训练好深度学习模型后,在推理测试会遇到类似下面加载模型失败的问题:
File "E:\workshop\CBDNet\predict_for_GPU.py", line 54, in <module>
model.load_state_dict(model_info['state_dict'])
File "D:\anaconda\envs\pytorch\Lib\site-packages\torch\nn\modules\module.py", line 2153, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.fcn.fcn.0.weight", "module.fcn.fcn.0.bias", "module.fcn.fcn.2.weight", "module.fcn.fcn.2.bias", "module.fcn.fcn.4.weight", "module.fcn.fcn.4.bias", "module.fcn.fcn.6.weight", "module.fcn.fcn.6.bias", "module.fcn.fcn.8.weight", "module.fcn.fcn.8.bias", "module.unet.inc.0.conv.0.weight", "module.unet.inc.0.conv.0.bias", "module.unet.inc.1.conv.0.weight", "module.unet.inc.1.conv.0.bias", "module.unet.conv1.0.conv.0.weight", "module.unet.conv1.0.conv.0.bias", "module.unet.conv1.1.conv.0.weight", "module.unet.conv1.1.conv.0.bias", "module.unet.conv1.2.c
二、问题分析
这个错误表明在加载 state_dict 时,模型的一些键(如 module.fcn.fcn.0.weight 等)无法匹配。这通常是因为你在加载模型时遇到了 DataParallel 模型和非 DataParallel 模型之间的不匹配问题。
如果你在多GPU上训练时使用了 torch.nn.DataParallel,模型的 state_dict 中的参数会带有 module. 前缀。而在加载到单GPU模型时,前缀不需要。
三、解决办法
3.1 单卡推理测试
当仅使用单卡推理测试,加载模型时,可以通过移除 module. 前缀来匹配当前的模型结构。手动修改 state_dict,去掉这些前缀。
实例代码见下:
# 加载模型
model_info = torch.load('path_to_model.pth')
# 获取 state_dict
state_dict = model_info['state_dict']
# 移除 'module.' 前缀
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v # 去掉 'module.' 前缀
else:
new_state_dict[k] = v
# 加载到模型
model.load_state_dict(new_state_dict)
3.2 多卡推理测试
如果你仍然希望在多GPU设备上运行模型,并且原来的模型是使用 torch.nn.DataParallel 训练的,可以继续使用 DataParallel 包装你的模型,确保参数名称不变,则直接使用model.load_state_dict()加载模型,前提是你自己的电脑有至少两张显卡可以正常调用。
加载多卡模型实例代码见下:
model = torch.nn.DataParallel(model)
model.load_state_dict(model_info['state_dict'])
四、总结
根据自己的电脑显卡块数集使用场景选择一张方式。如果是在单GPU上运行,推荐3.1方法;如果是在多GPU上运行,3.2方法更加适合。