Pytorch 模型load各种问题解决

出现unexpected key module.xxx.weight问题

有时候你的模型保存时含有 nn.DataParallel时,就会发现所有的dict都会有 module的前缀。
这时候加载含有module前缀的模型时,可能会出错。其实你只要移除这些前缀即可

  pretrained_net = Net_OLD()
  pretrained_net_dict = torch.load(save_path)
  new_state_dict = OrderedDict()
  for k, v in pretrained_net_dict.items():
      name = k[7:] # remove `module.`
      new_state_dict[name] = v
  # load params
  pretrained_net.load_state_dict(new_state_dict)

总结
保存的Dict是按照net.属性.weight来存储的。如果这个属性是一个Sequential,我们可以类似这样net.seqConvs.0.weight来获得。
当然在定义的类中,拿到Sequential的某一层用[], 比如self.seqConvs[0].weight.
strict=False是没有那么智能,遵循有相同的key则赋值,否则直接丢弃。
 

参考:https://blog.csdn.net/Hungryof/article/details/81364487?utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-1.control&dist_request_id=&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-1.control

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值