pytorch加载预训练模型

首先是键name一致的情况、比较简单:

	#这是我们自己网络模型参数的有序字典形式(网络参数名:值)
 	net_dict = net.state_dict()
    #这是实际加载的预训练好的网络模型参数的有序字典形式    
    pretrained_dict = torch.load(pretrained_path)
    #从预训练的参数中加载我们的网络中需要的模型参数(这个很重要、有时需要冻结某一层的参数、可用这条语句从预训练的整个网络参数中筛选出我们需要的某一层的参数)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net_dict}
    #字典的updata方法,进行字典的更新(个人感觉不是必要的)
    net_dict.update(pretrained_dict)
    #按照键与键的对应关系、加载网络参数    
    net.load_state_dict(net_dict)

然后是键name不一致的情况、也不难:
第一种情况:
当只想加载某一层的参数时,却发现预训练的模型参数名字 与 某一层的参数名字 仅仅差了 一个前缀的关系,那么可用切片的方式,返回一个新的有序字典,然后根据新的有序字典再加载参数。
先来看一下net网络trans这一层的参数:

model_dict = net.trans.state_dict()
print(model_dict.keys())

结果:
odict_keys([‘conv0.h2h_0.weight’, ‘conv0.h2h_0.bias’, ‘conv0.l2l_0.weight’, ‘conv0.l2l_0.bias’, ‘conv0.bnh_0.weight’, ‘conv0.bnh_0.bias’, ‘conv0.bnh_0.running_mean’, ‘conv0.bnh_0.running_var’, …
再来看一下整个预训练模型的参数:

 file = "C:\\Users\\Hou bin\\Desktop\\MINet_Res50.pth"
 pretrained_dict = torch.load(file, map_location='cpu')
 print(pretrained_dict.keys())

‘trans.conv0.h2h_0.weight’, ‘trans.conv0.h2h_0.bias’, ‘trans.conv0.l2l_0.weight’, ‘trans.conv0.l2l_0.bias’, ‘trans.conv0.bnh_0.weight’, ‘trans.conv0.bnh_0.bias’, ‘trans.conv0.bnh_0.running_mean’,…
有啥区别呢?
就是差了一个前缀!
如何加载?
代码如下:

 model_dict = net.trans.state_dict()#以有序字典形式返回trans这一层的全部参数
 file = "C:\\Users\\Hou bin\\Desktop\\MINet_Res50.pth"
 pretrained_dict = torch.load(file, map_location='cpu')
 pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}  #把预训练网络参数trans这一层的参数名字前的前缀通过切片去掉k[6:]: v,返回新的有序字典、这样加载的时候参数名才能对应啊
 #更新现有的model_dict
 model_dict.update(pretrained_dict)#这一步也可以不需要、下一行代码直接加载pretrained_dict
 net.trans.load_state_dict(model_dict)

第二种情况:
有时候你的模型保存时含有 nn.DataParallel时,就会发现所有的dict都会有 module的前缀。
这时候加载含有module前缀的模型时,可能会出错。其实你只要移除这些前缀即可
在这里插入图片描述

方法同上!

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值