How to selectively load neural network weights with PyTorch

How to load only the weights of the corresponding layers of a network with PyTorch

Introduction

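When is this needed? There are two typical cases. First, when training a model with transfer learning, the pretrained weights may not fully match the network you have built: one or more layers have a different structure. In that case you don't have to throw the pretrained weights away; you can selectively load only the layers whose structure matches. Second, when switching deep learning frameworks, for example converting a PyTorch weight file into a TensorFlow (or other framework's) weight file, you can use the same approach to read the PyTorch weights layer by layer and then write them out as a TensorFlow weight file.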
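For the first case, one common way to pick out "layers whose structure matches" is to keep a pretrained tensor only if the target network has a parameter or buffer with the same name and the same shape. The following is a minimal sketch of that idea, not the method used later in this post (which filters by name); `select_matching_weights` is a hypothetical helper, and it assumes both arguments are plain name → tensor mappings whose names already follow the same convention (for example, with any "module." prefix already removed).

```python
from collections import OrderedDict


def select_matching_weights(pretrained, model_state):
    """Split a pretrained state dict into tensors the model can accept and names it cannot.

    Hypothetical helper. Both arguments are name -> tensor mappings, e.g. the
    state dict from a checkpoint and net.state_dict(). A tensor is kept only if
    the model has an entry with the same name AND the same shape.
    """
    kept = OrderedDict()
    skipped = []
    for name, tensor in pretrained.items():
        target = model_state.get(name)
        # tuple() makes torch.Size, NumPy shapes and plain tuples comparable
        if target is not None and tuple(target.shape) == tuple(tensor.shape):
            kept[name] = tensor
        else:
            skipped.append(name)  # name absent from the model, or shape differs
    return kept, skipped
```

In PyTorch it would typically be called as `kept, skipped = select_matching_weights(pretrained_sd, net.state_dict())` followed by `net.load_state_dict(kept, strict=False)`, where `pretrained_sd` is a placeholder for the loaded state dict. The `skipped` list then names every layer left at its initial values, such as a classifier sized for a different number of classes. A matching shape does not guarantee that two tensors mean the same thing, so the list is still worth reading.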

Steps

Step 1: Understand how a PyTorch weight file stores its data

A plain PyTorch weight file looks like the following (using MobileNetV3 weights as an example).
[Figure: MobileNetV3 weight file, an OrderedDict of layer names and tensors]
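As you can see, the weights are stored in an ordered dictionary (OrderedDict) that maps each parameter name to its tensor. In practice, however, a saved checkpoint usually contains more than the weights: it also stores training information such as the number of epochs. In other words, the weights are wrapped in an outer dictionary, as shown below.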
[Figure: checkpoint dictionary wrapping the weights together with training information such as the epoch]
Once you know how the weights are stored, the rest is straightforward.
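Because a weight file may be either a bare state dict or a dictionary that wraps it, a small helper can normalize the two layouts before any filtering. This is a sketch under the assumption, taken from the example in this post, that the wrapper stores the weights under the key `'state_dict'`; other projects may use a different key, so it is exposed as a parameter. `extract_state_dict` is a hypothetical name, not a PyTorch function.

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint, key="state_dict"):
    """Return the name -> tensor mapping inside a loaded checkpoint (hypothetical helper).

    Handles both layouts: a bare state dict, and a dict that wraps it together
    with training info such as 'epoch'. The wrapper key is an assumption
    ('state_dict' here); pass a different `key` for other projects.
    """
    # Only unwrap if the entry under `key` is itself a mapping; in a bare state
    # dict every value is a tensor, so .get(key) is None or a tensor.
    if isinstance(checkpoint, Mapping) and isinstance(checkpoint.get(key), Mapping):
        return checkpoint[key]
    return checkpoint
```

For example, after `ckpt = torch.load(path, map_location="cpu")`, printing `list(ckpt.keys())` shows which layout you have, and `extract_state_dict(ckpt)` returns the OrderedDict of layer names and tensors in either case.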

Step 2: Implementation

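Taking MobileNetV3 as an example, when we load the official pretrained weights we import only the earlier layers and skip the last layer, because the number of neurons in the last layer must equal the number of classes in your task, so its pretrained weights cannot simply be reused. In addition, every layer name in the official pretrained weights carries an extra "module." prefix (this prefix typically appears when a model wrapped in nn.DataParallel or DistributedDataParallel is saved), so we have to strip "module." from each name. Otherwise none of the names match the ones in our own network and no layer's weights are loaded at all. With the default strict=True, net.load_state_dict() would at least raise an error; but if strict is set to False, the failure is silent: you believe you loaded the pretrained weights, while you are actually training from scratch. The code is as follows.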

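```python
import torch
from collections import OrderedDict
from mobilenetv3 import MobileNetV3_Small  # the MobileNetV3 network definition used in this post

net = MobileNetV3_Small(num_classes=7)
model_weight_path = "./mbv3_small.pth.tar"  # pretrained weights
# map_location="cpu" allows loading on a machine without a GPU
pre_weights = torch.load(model_weight_path, map_location="cpu")
print('Epochs trained:', pre_weights['epoch'])
#----------------------------------------------------------------------------------------------------------#
# Filter the layers and print their names
new_state_dict = OrderedDict()
for k, v in pre_weights['state_dict'].items():
    # print(k)
    # Skip the final classifier layer: its size depends on the number of classes
    if k not in ('module.linear4.weight', 'module.linear4.bias'):
        # Strip the 'module.' prefix (only if present) so names match our own network
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = v
        print(name)  # print each layer name
missing_keys, unexpected_keys = net.load_state_dict(new_state_dict, strict=False)
print('Final layer removed (missing keys):', missing_keys)
print('Unexpected layers:', unexpected_keys)
#---------------------------------------------------------------------------------------------------------#
```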

The output is as follows.

Epochs trained: 196
conv1.weight    
bn1.weight      
bn1.bias        
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
bneck.0.se.se.1.weight
bneck.0.se.se.2.weight
bneck.0.se.se.2.bias
bneck.0.se.se.2.running_mean
bneck.0.se.se.2.running_var
bneck.0.se.se.2.num_batches_tracked
bneck.0.se.se.4.weight
bneck.0.se.se.5.weight
bneck.0.se.se.5.bias
bneck.0.se.se.5.running_mean
bneck.0.se.se.5.running_var
bneck.0.se.se.5.num_batches_tracked
bneck.0.conv1.weight
bneck.0.bn1.weight
bneck.0.bn1.bias
bneck.0.bn1.running_mean
bneck.0.bn1.running_var
bneck.0.bn1.num_batches_tracked
bneck.0.conv2.weight
bneck.0.bn2.weight
bneck.0.bn2.bias
bneck.0.bn2.running_mean
bneck.0.bn2.running_var
bneck.0.bn2.num_batches_tracked
bneck.0.conv3.weight
bneck.0.bn3.weight
bneck.0.bn3.bias
bneck.0.bn3.running_mean
bneck.0.bn3.running_var
bneck.0.bn3.num_batches_tracked
bneck.1.conv1.weight
bneck.1.bn1.weight
bneck.1.bn1.bias
bneck.1.bn1.running_mean
bneck.1.bn1.running_var
bneck.1.bn1.num_batches_tracked
bneck.1.conv2.weight
bneck.1.bn2.weight
bneck.1.bn2.bias
bneck.1.bn2.running_mean
bneck.1.bn2.running_var
bneck.1.bn2.num_batches_tracked
bneck.1.conv3.weight
bneck.1.bn3.weight
bneck.1.bn3.bias
bneck.1.bn3.running_mean
bneck.1.bn3.running_var
bneck.1.bn3.num_batches_tracked
bneck.2.conv1.weight
bneck.2.bn1.weight
bneck.2.bn1.bias
bneck.2.bn1.running_mean
bneck.2.bn1.running_var
bneck.2.bn1.num_batches_tracked
bneck.2.conv2.weight
bneck.2.bn2.weight
bneck.2.bn2.bias
bneck.2.bn2.running_mean
bneck.2.bn2.running_var
bneck.2.bn2.num_batches_tracked
bneck.2.conv3.weight
bneck.2.bn3.weight
bneck.2.bn3.bias
bneck.2.bn3.running_mean
bneck.2.bn3.running_var
bneck.2.bn3.num_batches_tracked
bneck.3.se.se.1.weight
bneck.3.se.se.2.weight
bneck.3.se.se.2.bias
bneck.3.se.se.2.running_mean
bneck.3.se.se.2.running_var
bneck.3.se.se.2.num_batches_tracked
bneck.3.se.se.4.weight
bneck.3.se.se.5.weight
bneck.3.se.se.5.bias
bneck.3.se.se.5.running_mean
bneck.3.se.se.5.running_var
bneck.3.se.se.5.num_batches_tracked
bneck.3.conv1.weight
bneck.3.bn1.weight
bneck.3.bn1.bias
bneck.3.bn1.running_mean
bneck.3.bn1.running_var
bneck.3.bn1.num_batches_tracked
bneck.3.conv2.weight
bneck.3.bn2.weight
bneck.3.bn2.bias
bneck.3.bn2.running_mean
bneck.3.bn2.running_var
bneck.3.bn2.num_batches_tracked
bneck.3.conv3.weight
bneck.3.bn3.weight
bneck.3.bn3.bias
bneck.3.bn3.running_mean
bneck.3.bn3.running_var
bneck.3.bn3.num_batches_tracked
bneck.4.se.se.1.weight
bneck.4.se.se.2.weight
bneck.4.se.se.2.bias
bneck.4.se.se.2.running_mean
bneck.4.se.se.2.running_var
bneck.4.se.se.2.num_batches_tracked
bneck.4.se.se.4.weight
bneck.4.se.se.5.weight
bneck.4.se.se.5.bias
bneck.4.se.se.5.running_mean
bneck.4.se.se.5.running_var
bneck.4.se.se.5.num_batches_tracked
bneck.4.conv1.weight
bneck.4.bn1.weight
bneck.4.bn1.bias
bneck.4.bn1.running_mean
bneck.4.bn1.running_var
bneck.4.bn1.num_batches_tracked
bneck.4.conv2.weight
bneck.4.bn2.weight
bneck.4.bn2.bias
bneck.4.bn2.running_mean
bneck.4.bn2.running_var
bneck.4.bn2.num_batches_tracked
bneck.4.conv3.weight
bneck.4.bn3.weight
bneck.4.bn3.bias
bneck.4.bn3.running_mean
bneck.4.bn3.running_var
bneck.4.bn3.num_batches_tracked
bneck.5.se.se.1.weight
bneck.5.se.se.2.weight
bneck.5.se.se.2.bias
bneck.5.se.se.2.running_mean
bneck.5.se.se.2.running_var
bneck.5.se.se.2.num_batches_tracked
bneck.5.se.se.4.weight
bneck.5.se.se.5.weight
bneck.5.se.se.5.bias
bneck.5.se.se.5.running_mean
bneck.5.se.se.5.running_var
bneck.5.se.se.5.num_batches_tracked
bneck.5.conv1.weight
bneck.5.bn1.weight
bneck.5.bn1.bias
bneck.5.bn1.running_mean
bneck.5.bn1.running_var
bneck.5.bn1.num_batches_tracked
bneck.5.conv2.weight
bneck.5.bn2.weight
bneck.5.bn2.bias
bneck.5.bn2.running_mean
bneck.5.bn2.running_var
bneck.5.bn2.num_batches_tracked
bneck.5.conv3.weight
bneck.5.bn3.weight
bneck.5.bn3.bias
bneck.5.bn3.running_mean
bneck.5.bn3.running_var
bneck.5.bn3.num_batches_tracked
bneck.6.se.se.1.weight
bneck.6.se.se.2.weight
bneck.6.se.se.2.bias
bneck.6.se.se.2.running_mean
bneck.6.se.se.2.running_var
bneck.6.se.se.2.num_batches_tracked
bneck.6.se.se.4.weight
bneck.6.se.se.5.weight
bneck.6.se.se.5.bias
bneck.6.se.se.5.running_mean
bneck.6.se.se.5.running_var
bneck.6.se.se.5.num_batches_tracked
bneck.6.conv1.weight
bneck.6.bn1.weight
bneck.6.bn1.bias
bneck.6.bn1.running_mean
bneck.6.bn1.running_var
bneck.6.bn1.num_batches_tracked
bneck.6.conv2.weight
bneck.6.bn2.weight
bneck.6.bn2.bias
bneck.6.bn2.running_mean
bneck.6.bn2.running_var
bneck.6.bn2.num_batches_tracked
bneck.6.conv3.weight
bneck.6.bn3.weight
bneck.6.bn3.bias
bneck.6.bn3.running_mean
bneck.6.bn3.running_var
bneck.6.bn3.num_batches_tracked
bneck.6.shortcut.0.weight
bneck.6.shortcut.1.weight
bneck.6.shortcut.1.bias
bneck.6.shortcut.1.running_mean
bneck.6.shortcut.1.running_var
bneck.6.shortcut.1.num_batches_tracked
bneck.7.se.se.1.weight
bneck.7.se.se.2.weight
bneck.7.se.se.2.bias
bneck.7.se.se.2.running_mean
bneck.7.se.se.2.running_var
bneck.7.se.se.2.num_batches_tracked
bneck.7.se.se.4.weight
bneck.7.se.se.5.weight
bneck.7.se.se.5.bias
bneck.7.se.se.5.running_mean
bneck.7.se.se.5.running_var
bneck.7.se.se.5.num_batches_tracked
bneck.7.conv1.weight
bneck.7.bn1.weight
bneck.7.bn1.bias
bneck.7.bn1.running_mean
bneck.7.bn1.running_var
bneck.7.bn1.num_batches_tracked
bneck.7.conv2.weight
bneck.7.bn2.weight
bneck.7.bn2.bias
bneck.7.bn2.running_mean
bneck.7.bn2.running_var
bneck.7.bn2.num_batches_tracked
bneck.7.conv3.weight
bneck.7.bn3.weight
bneck.7.bn3.bias
bneck.7.bn3.running_mean
bneck.7.bn3.running_var
bneck.7.bn3.num_batches_tracked
bneck.8.se.se.1.weight
bneck.8.se.se.2.weight
bneck.8.se.se.2.bias
bneck.8.se.se.2.running_mean
bneck.8.se.se.2.running_var
bneck.8.se.se.2.num_batches_tracked
bneck.8.se.se.4.weight
bneck.8.se.se.5.weight
bneck.8.se.se.5.bias
bneck.8.se.se.5.running_mean
bneck.8.se.se.5.running_var
bneck.8.se.se.5.num_batches_tracked
bneck.8.conv1.weight
bneck.8.bn1.weight
bneck.8.bn1.bias
bneck.8.bn1.running_mean
bneck.8.bn1.running_var
bneck.8.bn1.num_batches_tracked
bneck.8.conv2.weight
bneck.8.bn2.weight
bneck.8.bn2.bias
bneck.8.bn2.running_mean
bneck.8.bn2.running_var
bneck.8.bn2.num_batches_tracked
bneck.8.conv3.weight
bneck.8.bn3.weight
bneck.8.bn3.bias
bneck.8.bn3.running_mean
bneck.8.bn3.running_var
bneck.8.bn3.num_batches_tracked
bneck.9.se.se.1.weight
bneck.9.se.se.2.weight
bneck.9.se.se.2.bias
bneck.9.se.se.2.running_mean
bneck.9.se.se.2.running_var
bneck.9.se.se.2.num_batches_tracked
bneck.9.se.se.4.weight
bneck.9.se.se.5.weight
bneck.9.se.se.5.bias
bneck.9.se.se.5.running_mean
bneck.9.se.se.5.running_var
bneck.9.se.se.5.num_batches_tracked
bneck.9.conv1.weight
bneck.9.bn1.weight
bneck.9.bn1.bias
bneck.9.bn1.running_mean
bneck.9.bn1.running_var
bneck.9.bn1.num_batches_tracked
bneck.9.conv2.weight
bneck.9.bn2.weight
bneck.9.bn2.bias
bneck.9.bn2.running_mean
bneck.9.bn2.running_var
bneck.9.bn2.num_batches_tracked
bneck.9.conv3.weight
bneck.9.bn3.weight
bneck.9.bn3.bias
bneck.9.bn3.running_mean
bneck.9.bn3.running_var
bneck.9.bn3.num_batches_tracked
bneck.10.se.se.1.weight
bneck.10.se.se.2.weight
bneck.10.se.se.2.bias
bneck.10.se.se.2.running_mean
bneck.10.se.se.2.running_var
bneck.10.se.se.2.num_batches_tracked
bneck.10.se.se.4.weight
bneck.10.se.se.5.weight
bneck.10.se.se.5.bias
bneck.10.se.se.5.running_mean
bneck.10.se.se.5.running_var
bneck.10.se.se.5.num_batches_tracked
bneck.10.conv1.weight
bneck.10.bn1.weight
bneck.10.bn1.bias
bneck.10.bn1.running_mean
bneck.10.bn1.running_var
bneck.10.bn1.num_batches_tracked
bneck.10.conv2.weight
bneck.10.bn2.weight
bneck.10.bn2.bias
bneck.10.bn2.running_mean
bneck.10.bn2.running_var
bneck.10.bn2.num_batches_tracked
bneck.10.conv3.weight
bneck.10.bn3.weight
bneck.10.bn3.bias
bneck.10.bn3.running_mean
bneck.10.bn3.running_var
bneck.10.bn3.num_batches_tracked
conv2.weight
bn2.weight
bn2.bias
bn2.running_mean
bn2.running_var
bn2.num_batches_tracked
linear3.weight
linear3.bias
bn3.weight
bn3.bias
bn3.running_mean
bn3.running_var
bn3.num_batches_tracked
Final layer removed (missing keys): ['linear4.weight', 'linear4.bias']
Unexpected layers: []
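When adapting this to other checkpoints, it helps to strip the prefix only where it is actually present and to verify the result of `load_state_dict`, since, as noted earlier, with `strict=False` a failed load shows up only in the returned key lists. The sketch below does both; `strip_prefix` and `check_load_result` are hypothetical helpers, not PyTorch APIs. (Recent PyTorch versions also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which removes a prefix in place.)

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module.", exclude=()):
    """Return a new OrderedDict with `prefix` removed from keys that start with it.

    Keys listed in `exclude` are compared BEFORE stripping and dropped, e.g. the
    final classifier whose size depends on the number of classes.
    """
    out = OrderedDict()
    for key, value in state_dict.items():
        if key in exclude:
            continue
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        out[new_key] = value
    return out


def check_load_result(missing_keys, unexpected_keys, allowed_missing=()):
    """Raise if a strict=False load skipped more than was intended."""
    not_allowed = sorted(set(missing_keys) - set(allowed_missing))
    if not_allowed or unexpected_keys:
        raise RuntimeError(
            f"unexpectedly missing keys: {not_allowed}; "
            f"unexpected keys: {list(unexpected_keys)}"
        )
```

Applied to this post's example: `new_sd = strip_prefix(pre_weights['state_dict'], exclude=('module.linear4.weight', 'module.linear4.bias'))`, then `missing, unexpected = net.load_state_dict(new_sd, strict=False)` and `check_load_result(missing, unexpected, allowed_missing=('linear4.weight', 'linear4.bias'))`. If the prefix had not been removed, `missing` would list every layer of the network and the check would raise instead of letting training quietly start from scratch.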
Author: Ai_Taoism

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值