unexpected key "module.conv1_1.weight" in state_dict

torch加载模型时出现如下错误

异常位置
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
network.load_state_dict(torch.load(save_path))
异常信息
File "/data/Muyi/Github/EnlightenGAN/models/base_model.py", line 54, in load_network
network.load_state_dict(torch.load(save_path))
File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "module.conv1_1.weight" in state_dict'

异常原因
最终原因在此:训练时使用GPU,使用了torch.nn.DataParallel(),而此时预测没有使用GPU,即没有使用此模块导致上述异常
if len(gpu_ids) > 0:
    netG.cuda(device=gpu_ids[0])
    netG = torch.nn.DataParallel(netG, gpu_ids)
解决方案
1):加上torch.nn.DataParallel()模块,类似我的问题只需要使用GPU即可正常运行
model = torch.nn.DataParallel(model)
cudnn.benchmark = True
2):将原来字典中module.删除掉
network.load_state_dict({k.replace('module.',''):v for k,v in torch.load(save_path).items()})
更改原来代码如下,即可在CPU/GPU下都正常运行
if len(self.gpu_ids):
    network.load_state_dict(torch.load(save_path))
else:
    network.load_state_dict({k.replace('module.',''):v for k,v in torch.load(save_path).items()})
  • 11
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值