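How to selectively load the weights of specific layers of a neural network in PyTorch
Preface
When is this needed? There are two cases. First, in transfer learning, the structure of the pretrained weights may not exactly match the network you have built, because one or more layers differ. You do not have to discard the pretrained weights in that case; you can selectively load the layers whose structure matches. Second, when switching deep learning frameworks, for example converting a PyTorch weight file into a weight file for TensorFlow or another framework, the same approach lets you read the PyTorch weights layer by layer and then write them out in the target format.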
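As a rough illustration of the first case, the sketch below picks out the entries whose name and shape agree between a pretrained checkpoint and a new model. The helper `select_compatible` and the shapes are hypothetical and chosen only for illustration; plain tuples stand in for tensor shapes so the snippet runs without PyTorch.

```python
def select_compatible(pretrained_shapes, model_shapes):
    """Return the names present in both mappings with identical shapes."""
    return [
        name
        for name, shape in pretrained_shapes.items()
        if name in model_shapes and tuple(model_shapes[name]) == tuple(shape)
    ]


# Hypothetical shapes: only the classifier differs (1000 classes vs. 7 classes).
pretrained = {
    "conv1.weight": (16, 3, 3, 3),
    "linear4.weight": (1000, 1280),
    "linear4.bias": (1000,),
}
model = {
    "conv1.weight": (16, 3, 3, 3),
    "linear4.weight": (7, 1280),
    "linear4.bias": (7,),
}

print(select_compatible(pretrained, model))  # ['conv1.weight']
```

With real tensors, the shape mappings could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`.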
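Steps
Step 1: Understand how PyTorch stores weights in a weight file
A plain PyTorch weight file (using the MobileNetV3 weights as an example) is simply an ordered dictionary that maps each layer's parameter name to its tensor.
In practice, however, a saved file usually contains more than the weights: training information such as the number of epochs trained is stored as well. In that case the ordered dictionary is wrapped in an outer dictionary, with entries such as 'epoch' and 'state_dict'.
Once the storage structure is clear, the rest is straightforward.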
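To make the two layouts concrete, here is a minimal sketch that mimics them with plain Python objects. The strings stand in for tensors, and `unwrap_state_dict` is a hypothetical helper, not a PyTorch API. The key names 'state_dict' and 'epoch' match the checkpoint used in this post; other checkpoints may use different names.

```python
from collections import OrderedDict

# Layout 1: a bare state dict, an ordered mapping from parameter name to tensor.
# Strings are stand-ins for tensors so the sketch runs without PyTorch.
bare = OrderedDict([
    ("module.conv1.weight", "<tensor>"),
    ("module.bn1.weight", "<tensor>"),
])

# Layout 2: the same state dict wrapped in an outer dict with training metadata.
wrapped = {"epoch": 196, "state_dict": bare}


def unwrap_state_dict(checkpoint, key="state_dict"):
    """Return the inner state dict if the checkpoint is wrapped, else the checkpoint itself."""
    if key in checkpoint and isinstance(checkpoint[key], dict):
        return checkpoint[key]
    return checkpoint


# Both layouts yield the same list of parameter names.
for ckpt in (bare, wrapped):
    state = unwrap_state_dict(ckpt)
    print(list(state.keys()))
```

With a real file, `checkpoint` would be the object returned by `torch.load`, and printing its top-level keys is usually enough to tell which layout you have.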
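Step 2: Implementation
Taking MobileNetV3 as an example, when importing the official pretrained weights we load only the earlier layers and skip the last layer. The number of neurons in the last layer must equal your number of classes, so that layer cannot simply reuse someone else's pretrained weights. Every layer name in the official pretrained weights also carries a "module." prefix (a prefix that typically appears when a model is saved while wrapped in nn.DataParallel), so we need to strip "module." from each name. Otherwise the names will not match those in our own network and none of the weights will be loaded. Worse, if the strict parameter of net.load_state_dict() is set to False, this fails silently: you think the weights were loaded, but you are actually training from scratch. The code is as follows.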
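import torch
from collections import OrderedDict
from mobilenetv3 import MobileNetV3_Small
net = MobileNetV3_Small(num_classes=7)
model_weight_path = "./mbv3_small.pth.tar"  # pretrained weights
# Load onto the CPU so that no GPU is required to read the file
pre_weights = torch.load(model_weight_path, map_location="cpu")
print('Epochs trained:', pre_weights['epoch'])
#----------------------------------------------------------------------------------------------------------#
# Filter the pretrained weights and print the name of each layer that is kept
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in pre_weights['state_dict'].items():
    # print(k)
    # Skip the final classification layer
    if k != 'module.linear4.weight' and k != 'module.linear4.bias':
        # Strip the 'module.' prefix from the layer name
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
        print(name)  # print each layer name
# strict=False tolerates the missing final layer and reports it in missing_keys
missing_keys, unexpected_keys = net.load_state_dict(new_state_dict, strict=False)
print('Missing keys (final layer dropped):', missing_keys)
print('Unexpected keys:', unexpected_keys)
#---------------------------------------------------------------------------------------------------------#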
The output is shown below.
Epochs trained: 196
conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
bneck.0.se.se.1.weight
bneck.0.se.se.2.weight
bneck.0.se.se.2.bias
bneck.0.se.se.2.running_mean
bneck.0.se.se.2.running_var
bneck.0.se.se.2.num_batches_tracked
bneck.0.se.se.4.weight
bneck.0.se.se.5.weight
bneck.0.se.se.5.bias
bneck.0.se.se.5.running_mean
bneck.0.se.se.5.running_var
bneck.0.se.se.5.num_batches_tracked
bneck.0.conv1.weight
bneck.0.bn1.weight
bneck.0.bn1.bias
bneck.0.bn1.running_mean
bneck.0.bn1.running_var
bneck.0.bn1.num_batches_tracked
bneck.0.conv2.weight
bneck.0.bn2.weight
bneck.0.bn2.bias
bneck.0.bn2.running_mean
bneck.0.bn2.running_var
bneck.0.bn2.num_batches_tracked
bneck.0.conv3.weight
bneck.0.bn3.weight
bneck.0.bn3.bias
bneck.0.bn3.running_mean
bneck.0.bn3.running_var
bneck.0.bn3.num_batches_tracked
bneck.1.conv1.weight
bneck.1.bn1.weight
bneck.1.bn1.bias
bneck.1.bn1.running_mean
bneck.1.bn1.running_var
bneck.1.bn1.num_batches_tracked
bneck.1.conv2.weight
bneck.1.bn2.weight
bneck.1.bn2.bias
bneck.1.bn2.running_mean
bneck.1.bn2.running_var
bneck.1.bn2.num_batches_tracked
bneck.1.conv3.weight
bneck.1.bn3.weight
bneck.1.bn3.bias
bneck.1.bn3.running_mean
bneck.1.bn3.running_var
bneck.1.bn3.num_batches_tracked
bneck.2.conv1.weight
bneck.2.bn1.weight
bneck.2.bn1.bias
bneck.2.bn1.running_mean
bneck.2.bn1.running_var
bneck.2.bn1.num_batches_tracked
bneck.2.conv2.weight
bneck.2.bn2.weight
bneck.2.bn2.bias
bneck.2.bn2.running_mean
bneck.2.bn2.running_var
bneck.2.bn2.num_batches_tracked
bneck.2.conv3.weight
bneck.2.bn3.weight
bneck.2.bn3.bias
bneck.2.bn3.running_mean
bneck.2.bn3.running_var
bneck.2.bn3.num_batches_tracked
bneck.3.se.se.1.weight
bneck.3.se.se.2.weight
bneck.3.se.se.2.bias
bneck.3.se.se.2.running_mean
bneck.3.se.se.2.running_var
bneck.3.se.se.2.num_batches_tracked
bneck.3.se.se.4.weight
bneck.3.se.se.5.weight
bneck.3.se.se.5.bias
bneck.3.se.se.5.running_mean
bneck.3.se.se.5.running_var
bneck.3.se.se.5.num_batches_tracked
bneck.3.conv1.weight
bneck.3.bn1.weight
bneck.3.bn1.bias
bneck.3.bn1.running_mean
bneck.3.bn1.running_var
bneck.3.bn1.num_batches_tracked
bneck.3.conv2.weight
bneck.3.bn2.weight
bneck.3.bn2.bias
bneck.3.bn2.running_mean
bneck.3.bn2.running_var
bneck.3.bn2.num_batches_tracked
bneck.3.conv3.weight
bneck.3.bn3.weight
bneck.3.bn3.bias
bneck.3.bn3.running_mean
bneck.3.bn3.running_var
bneck.3.bn3.num_batches_tracked
bneck.4.se.se.1.weight
bneck.4.se.se.2.weight
bneck.4.se.se.2.bias
bneck.4.se.se.2.running_mean
bneck.4.se.se.2.running_var
bneck.4.se.se.2.num_batches_tracked
bneck.4.se.se.4.weight
bneck.4.se.se.5.weight
bneck.4.se.se.5.bias
bneck.4.se.se.5.running_mean
bneck.4.se.se.5.running_var
bneck.4.se.se.5.num_batches_tracked
bneck.4.conv1.weight
bneck.4.bn1.weight
bneck.4.bn1.bias
bneck.4.bn1.running_mean
bneck.4.bn1.running_var
bneck.4.bn1.num_batches_tracked
bneck.4.conv2.weight
bneck.4.bn2.weight
bneck.4.bn2.bias
bneck.4.bn2.running_mean
bneck.4.bn2.running_var
bneck.4.bn2.num_batches_tracked
bneck.4.conv3.weight
bneck.4.bn3.weight
bneck.4.bn3.bias
bneck.4.bn3.running_mean
bneck.4.bn3.running_var
bneck.4.bn3.num_batches_tracked
bneck.5.se.se.1.weight
bneck.5.se.se.2.weight
bneck.5.se.se.2.bias
bneck.5.se.se.2.running_mean
bneck.5.se.se.2.running_var
bneck.5.se.se.2.num_batches_tracked
bneck.5.se.se.4.weight
bneck.5.se.se.5.weight
bneck.5.se.se.5.bias
bneck.5.se.se.5.running_mean
bneck.5.se.se.5.running_var
bneck.5.se.se.5.num_batches_tracked
bneck.5.conv1.weight
bneck.5.bn1.weight
bneck.5.bn1.bias
bneck.5.bn1.running_mean
bneck.5.bn1.running_var
bneck.5.bn1.num_batches_tracked
bneck.5.conv2.weight
bneck.5.bn2.weight
bneck.5.bn2.bias
bneck.5.bn2.running_mean
bneck.5.bn2.running_var
bneck.5.bn2.num_batches_tracked
bneck.5.conv3.weight
bneck.5.bn3.weight
bneck.5.bn3.bias
bneck.5.bn3.running_mean
bneck.5.bn3.running_var
bneck.5.bn3.num_batches_tracked
bneck.6.se.se.1.weight
bneck.6.se.se.2.weight
bneck.6.se.se.2.bias
bneck.6.se.se.2.running_mean
bneck.6.se.se.2.running_var
bneck.6.se.se.2.num_batches_tracked
bneck.6.se.se.4.weight
bneck.6.se.se.5.weight
bneck.6.se.se.5.bias
bneck.6.se.se.5.running_mean
bneck.6.se.se.5.running_var
bneck.6.se.se.5.num_batches_tracked
bneck.6.conv1.weight
bneck.6.bn1.weight
bneck.6.bn1.bias
bneck.6.bn1.running_mean
bneck.6.bn1.running_var
bneck.6.bn1.num_batches_tracked
bneck.6.conv2.weight
bneck.6.bn2.weight
bneck.6.bn2.bias
bneck.6.bn2.running_mean
bneck.6.bn2.running_var
bneck.6.bn2.num_batches_tracked
bneck.6.conv3.weight
bneck.6.bn3.weight
bneck.6.bn3.bias
bneck.6.bn3.running_mean
bneck.6.bn3.running_var
bneck.6.bn3.num_batches_tracked
bneck.6.shortcut.0.weight
bneck.6.shortcut.1.weight
bneck.6.shortcut.1.bias
bneck.6.shortcut.1.running_mean
bneck.6.shortcut.1.running_var
bneck.6.shortcut.1.num_batches_tracked
bneck.7.se.se.1.weight
bneck.7.se.se.2.weight
bneck.7.se.se.2.bias
bneck.7.se.se.2.running_mean
bneck.7.se.se.2.running_var
bneck.7.se.se.2.num_batches_tracked
bneck.7.se.se.4.weight
bneck.7.se.se.5.weight
bneck.7.se.se.5.bias
bneck.7.se.se.5.running_mean
bneck.7.se.se.5.running_var
bneck.7.se.se.5.num_batches_tracked
bneck.7.conv1.weight
bneck.7.bn1.weight
bneck.7.bn1.bias
bneck.7.bn1.running_mean
bneck.7.bn1.running_var
bneck.7.bn1.num_batches_tracked
bneck.7.conv2.weight
bneck.7.bn2.weight
bneck.7.bn2.bias
bneck.7.bn2.running_mean
bneck.7.bn2.running_var
bneck.7.bn2.num_batches_tracked
bneck.7.conv3.weight
bneck.7.bn3.weight
bneck.7.bn3.bias
bneck.7.bn3.running_mean
bneck.7.bn3.running_var
bneck.7.bn3.num_batches_tracked
bneck.8.se.se.1.weight
bneck.8.se.se.2.weight
bneck.8.se.se.2.bias
bneck.8.se.se.2.running_mean
bneck.8.se.se.2.running_var
bneck.8.se.se.2.num_batches_tracked
bneck.8.se.se.4.weight
bneck.8.se.se.5.weight
bneck.8.se.se.5.bias
bneck.8.se.se.5.running_mean
bneck.8.se.se.5.running_var
bneck.8.se.se.5.num_batches_tracked
bneck.8.conv1.weight
bneck.8.bn1.weight
bneck.8.bn1.bias
bneck.8.bn1.running_mean
bneck.8.bn1.running_var
bneck.8.bn1.num_batches_tracked
bneck.8.conv2.weight
bneck.8.bn2.weight
bneck.8.bn2.bias
bneck.8.bn2.running_mean
bneck.8.bn2.running_var
bneck.8.bn2.num_batches_tracked
bneck.8.conv3.weight
bneck.8.bn3.weight
bneck.8.bn3.bias
bneck.8.bn3.running_mean
bneck.8.bn3.running_var
bneck.8.bn3.num_batches_tracked
bneck.9.se.se.1.weight
bneck.9.se.se.2.weight
bneck.9.se.se.2.bias
bneck.9.se.se.2.running_mean
bneck.9.se.se.2.running_var
bneck.9.se.se.2.num_batches_tracked
bneck.9.se.se.4.weight
bneck.9.se.se.5.weight
bneck.9.se.se.5.bias
bneck.9.se.se.5.running_mean
bneck.9.se.se.5.running_var
bneck.9.se.se.5.num_batches_tracked
bneck.9.conv1.weight
bneck.9.bn1.weight
bneck.9.bn1.bias
bneck.9.bn1.running_mean
bneck.9.bn1.running_var
bneck.9.bn1.num_batches_tracked
bneck.9.conv2.weight
bneck.9.bn2.weight
bneck.9.bn2.bias
bneck.9.bn2.running_mean
bneck.9.bn2.running_var
bneck.9.bn2.num_batches_tracked
bneck.9.conv3.weight
bneck.9.bn3.weight
bneck.9.bn3.bias
bneck.9.bn3.running_mean
bneck.9.bn3.running_var
bneck.9.bn3.num_batches_tracked
bneck.10.se.se.1.weight
bneck.10.se.se.2.weight
bneck.10.se.se.2.bias
bneck.10.se.se.2.running_mean
bneck.10.se.se.2.running_var
bneck.10.se.se.2.num_batches_tracked
bneck.10.se.se.4.weight
bneck.10.se.se.5.weight
bneck.10.se.se.5.bias
bneck.10.se.se.5.running_mean
bneck.10.se.se.5.running_var
bneck.10.se.se.5.num_batches_tracked
bneck.10.conv1.weight
bneck.10.bn1.weight
bneck.10.bn1.bias
bneck.10.bn1.running_mean
bneck.10.bn1.running_var
bneck.10.bn1.num_batches_tracked
bneck.10.conv2.weight
bneck.10.bn2.weight
bneck.10.bn2.bias
bneck.10.bn2.running_mean
bneck.10.bn2.running_var
bneck.10.bn2.num_batches_tracked
bneck.10.conv3.weight
bneck.10.bn3.weight
bneck.10.bn3.bias
bneck.10.bn3.running_mean
bneck.10.bn3.running_var
bneck.10.bn3.num_batches_tracked
conv2.weight
bn2.weight
bn2.bias
bn2.running_mean
bn2.running_var
bn2.num_batches_tracked
linear3.weight
linear3.bias
bn3.weight
bn3.bias
bn3.running_mean
bn3.running_var
bn3.num_batches_tracked
Missing keys (final layer dropped): ['linear4.weight', 'linear4.bias']
Unexpected keys: []
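The silent failure described in Step 2 can be guarded against with a quick key check before loading. The sketch below counts how many checkpoint names actually exist in the model, with and without the prefix stripped. `strip_prefix` and `count_matches` are hypothetical helpers, and the key lists are a small made-up sample rather than the full MobileNetV3 list.

```python
def strip_prefix(key, prefix="module."):
    """Remove the prefix only when the key actually starts with it."""
    return key[len(prefix):] if key.startswith(prefix) else key


def count_matches(model_keys, checkpoint_keys):
    """Count the checkpoint keys that also exist in the model."""
    model_keys = set(model_keys)
    return sum(1 for k in checkpoint_keys if k in model_keys)


# Made-up sample of names; a real model has many more.
model_keys = ["conv1.weight", "bn1.weight", "linear4.weight"]
ckpt_keys = ["module.conv1.weight", "module.bn1.weight"]

# Without stripping the prefix, nothing matches and nothing would be loaded.
print(count_matches(model_keys, ckpt_keys))  # 0

# After stripping the prefix, both checkpoint entries match.
print(count_matches(model_keys, [strip_prefix(k) for k in ckpt_keys]))  # 2
```

With a real network, `model_keys` could be taken from `net.state_dict().keys()`. Raising an error when the count is zero is one way to avoid unknowingly training from scratch when strict=False is used.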