上一节,我们给出了模型加载和保存的简要示例,但是,我们有时候会用别人的参数,他们的层参数名和我们的名称很容易不同,因此这里将会对源码进入深入剖析,分析参数提取和保存是如何实现的。
我们使用pytorch的VGG16预训练模型,加载,返回其类型。可以发现,是OrderedDict类型,也就是字典类型,既然是字典,每个层的参数就是用了一个键值对保存起来了。
model = torch.load('vgg16-397923af.pth')
list_keys = list(model.keys()) # 将模型中的keys转换为list
print(type(model))
print(list_keys)
print(type(model[list_keys[0]]))
输出:
<class 'collections.OrderedDict'>
['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']
<class 'torch.Tensor'>
很明显,这个模型数据里面就是这个模型所需的参数,每个参数用一个键值对存储,每个参数都是一个Tensor矩阵。
关于这个参数的命名,因为这里面的Pytorch使用一个Sequence存的,所以这个参数名的命名规则就是. Sequence变量名+第几层+每层内部的参数名
下面我给出个例子,说明如何将这些参数拷贝到自己的模型上,下面自己写了一个VGG模型
class VGG(nn.Module):
def __init__(self, num_classes=100):
super(VGG, self).__init__()
layers = nn.ModuleList()
in_dim = 3
out_dim = 64
for i in range(13):
layers.extend([nn.Conv2d(in_dim, out_dim, 3, 1, 1),
nn.ReLU(inplace=True)])
out_dim = in_dim
if i == 1 or i == 3 or i == 6 or i == 9 or i == 12:
layers.append(nn.MaxPool2d(2,2))
if i != 9:
out_dim *= 2
self.fea = nn.Sequential(layers)
self.cls = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),
nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),
nn.Linear(4096, num_classes),
nn.Softmax(num_classes)
)
def forward(self, x):
x = self.fea(x)
x = x.view(x.size(0), -1)
x = self.cls(x)
return x
实例化一个VGG,并输出这个模型所含的参数
vgg = VGG()
print(list(vgg.state_dict().keys()))
输出:
['fea.0.0.weight', 'fea.0.0.bias', 'fea.0.2.weight', 'fea.0.2.bias', 'fea.0.5.weight', 'fea.0.5.bias', 'fea.0.7.weight', 'fea.0.7.bias', 'fea.0.10.weight', 'fea.0.10.bias', 'fea.0.12.weight', 'fea.0.12.bias', 'fea.0.14.weight', 'fea.0.14.bias', 'fea.0.17.weight', 'fea.0.17.bias', 'fea.0.19.weight', 'fea.0.19.bias', 'fea.0.21.weight', 'fea.0.21.bias', 'fea.0.24.weight', 'fea.0.24.bias', 'fea.0.26.weight', 'fea.0.26.bias', 'fea.0.28.weight', 'fea.0.28.bias', 'cls.0.weight', 'cls.0.bias', 'cls.3.weight', 'cls.3.bias', 'cls.6.weight', 'cls.6.bias']
根据之前的博客Pytorch学习(2) —— 网络工具箱 TORCH.NN 基本类用法,我们使用load_state_dict进行模型加载
比如用下面的方法可以将另一个模型的参数转到自己的参数上,记住strict一定要设置为false,否则会出错。
vgg.load_state_dict({'fea.0.0.weight':model['features.0.weight']}, strict=False)
总结
本部分介绍了如何将预训练的模型参数加载到自己的模型上,有时候我们的网络参数是由两个其他网络构成,那么本部分提供了一种加载方法。
至此,模型的加载用法已经完成,下面就开始介绍如何构建模型。