Pytorch只加载部分参数权重 load (pth文件) & 加载模型不完全匹配&module.后缀问题

当使用Pytorch作为学习框架时:

我们已经建立好自己的model,但是想加载他人的网络模型架构时,总会出现一些问题,其中最为常见的三种问题

1. 网络每层的名称不相同,导致这个问题大体有两种原因

1.1 网络本身命名不同

1.2 有些网络在多块GPU上进行训练会使得每一层网络名称多一个module.的后缀,而有的只在CPU上训练则没有module.这个后缀。

针对上面两种1.1 与1.2 我们需要弄清下面这个函数的原理

model.load_state_dict()

该函数是指当我们已经构造好了一个模型后,可能要加载一些训练好的模型参数。举例子如下:

假设  trained.pth 是一个训练好的网络的模型参数存储载体。model = Net()是我们刚刚生成的一个新模型,我们希望model将trained.pth中的参数加载加载进来,这时我们就需要采用上述函数作为加载函数。

而这个函数加载的数据类型是OrderedDict,所谓OrderedDict,就是一个有着固定顺序的字典,当我们想要自己创建一个OrderedDict时需要按照下述引用OrderedDict。

from collections import OrderedDict

以1.2为例,如果我们想要将待加载模型trained.pth中的层的名字中的module.这个后缀去掉,则需采用以下代码。

new_state_dict = OrderedDict()
for k in 待加载的网络参数的位置(也是OrderedDict类型):
    name = k.replace('module.', '')
    new_state_dict[name] = checkpoint['model_state_dict'].setdefault(k)
self.model.load_state_dict(new_state_dict)

2. 加载模型不完全匹配

当加载模型不完全匹配时,我们可以采用

model.load_state_dict(new_state_dict,strict=False)

其中的strict参数设置为false则使得网络不完全加载目标网络参数,为True则完全粘贴过来,但凡有一点不一致都会报错。

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值