Pytorch学习(6) —— 加载模型部分参数的用法

上一节,我们给出了模型加载和保存的简要示例,但是,我们有时候会用别人的参数,他们的层参数名和我们的名称很容易不同,因此这里将会对源码进入深入剖析,分析参数提取和保存是如何实现的。

我们使用pytorch的VGG16预训练模型,加载,返回其类型。可以发现,是OrderedDict类型,也就是字典类型,既然是字典,每个层的参数就是用了一个键值对保存起来了。

model = torch.load('vgg16-397923af.pth')
list_keys = list(model.keys()) # 将模型中的keys转换为list
print(type(model))
print(list_keys)
print(type(model[list_keys[0]]))
输出:
<class 'collections.OrderedDict'>
['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']
<class 'torch.Tensor'>

很明显,这个模型数据里面就是这个模型所需的参数,每个参数用一个键值对存储,每个参数都是一个Tensor矩阵。

关于这个参数的命名,因为这里面的Pytorch使用一个Sequence存的,所以这个参数名的命名规则就是. Sequence变量名+第几层+每层内部的参数名

下面我给出个例子,说明如何将这些参数拷贝到自己的模型上,下面自己写了一个VGG模型

class VGG(nn.Module):
    def __init__(self, num_classes=100):
        super(VGG, self).__init__()
        layers = nn.ModuleList()
        in_dim = 3
        out_dim = 64
        for i in range(13):
            layers.extend([nn.Conv2d(in_dim, out_dim, 3, 1, 1),
                           nn.ReLU(inplace=True)])
            out_dim = in_dim
            if i == 1 or i == 3 or i == 6 or i == 9 or i == 12:
                layers.append(nn.MaxPool2d(2,2))
                if i != 9:
                    out_dim *= 2
        self.fea = nn.Sequential(layers)
        self.cls = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),
            nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),
            nn.Linear(4096, num_classes),
            nn.Softmax(num_classes)
        )

        def forward(self, x):
            x = self.fea(x)
            x = x.view(x.size(0), -1)
            x = self.cls(x)
            return x

实例化一个VGG,并输出这个模型所含的参数

vgg = VGG()
print(list(vgg.state_dict().keys()))
输出:
['fea.0.0.weight', 'fea.0.0.bias', 'fea.0.2.weight', 'fea.0.2.bias', 'fea.0.5.weight', 'fea.0.5.bias', 'fea.0.7.weight', 'fea.0.7.bias', 'fea.0.10.weight', 'fea.0.10.bias', 'fea.0.12.weight', 'fea.0.12.bias', 'fea.0.14.weight', 'fea.0.14.bias', 'fea.0.17.weight', 'fea.0.17.bias', 'fea.0.19.weight', 'fea.0.19.bias', 'fea.0.21.weight', 'fea.0.21.bias', 'fea.0.24.weight', 'fea.0.24.bias', 'fea.0.26.weight', 'fea.0.26.bias', 'fea.0.28.weight', 'fea.0.28.bias', 'cls.0.weight', 'cls.0.bias', 'cls.3.weight', 'cls.3.bias', 'cls.6.weight', 'cls.6.bias']

根据之前的博客Pytorch学习(2) —— 网络工具箱 TORCH.NN 基本类用法,我们使用load_state_dict进行模型加载

比如用下面的方法可以将另一个模型的参数转到自己的参数上,记住strict一定要设置为false,否则会出错。

vgg.load_state_dict({'fea.0.0.weight':model['features.0.weight']}, strict=False)

总结

本部分介绍了如何将预训练的模型参数加载到自己的模型上,有时候我们的网络参数是由两个其他网络构成,那么本部分提供了一种加载方法。

至此,模型的加载用法已经完成,下面就开始介绍如何构建模型。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值