Pytorch学习(6) —— 加载模型部分参数的用法

上一节,我们给出了模型加载和保存的简要示例,但是,我们有时候会用别人的参数,他们的层参数名和我们的名称很容易不同,因此这里将会对源码进入深入剖析,分析参数提取和保存是如何实现的。

我们使用pytorch的VGG16预训练模型,加载,返回其类型。可以发现,是OrderedDict类型,也就是字典类型,既然是字典,每个层的参数就是用了一个键值对保存起来了。

model = torch.load('vgg16-397923af.pth')
list_keys = list(model.keys()) # 将模型中的keys转换为list
print(type(model))
print(list_keys)
print(type(model[list_keys[0]]))
输出:
<class 'collections.OrderedDict'>
['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']
<class 'torch.Tensor'>

很明显,这个模型数据里面就是这个模型所需的参数,每个参数用一个键值对存储,每个参数都是一个Tensor矩阵。

关于这个参数的命名,因为这里面的Pytorch使用一个Sequence存的,所以这个参数名的命名规则就是. Sequence变量名+第几层+每层内部的参数名

下面我给出个例子,说明如何将这些参数拷贝到自己的模型上,下面自己写了一个VGG模型

class VGG(nn.Module):
    def __init__(self, num_classes=100):
        super(VGG, self).__init__()
        layers = nn.ModuleList()
        in_dim = 3
        out_dim = 64
        for i in range(13):
            layers.extend([nn.Conv2d(in_dim, out_dim, 3, 1, 1),
                           nn.ReLU(inplace=True)])
            out_dim = in_dim
            if i == 1 or i == 3 or i == 6 or i == 9 or i == 12:
                layers.append(nn.MaxPool2d(2,2))
                if i != 9:
                    out_dim *= 2
        self.fea = nn.Sequential(layers)
        self.cls = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),
            nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),
            nn.Linear(4096, num_classes),
            nn.Softmax(num_classes)
        )

        def forward(self, x):
            x = self.fea(x)
            x = x.view(x.size(0), -1)
            x = self.cls(x)
            return x

实例化一个VGG,并输出这个模型所含的参数

vgg = VGG()
print(list(vgg.state_dict().keys()))
输出:
['fea.0.0.weight', 'fea.0.0.bias', 'fea.0.2.weight', 'fea.0.2.bias', 'fea.0.5.weight', 'fea.0.5.bias', 'fea.0.7.weight', 'fea.0.7.bias', 'fea.0.10.weight', 'fea.0.10.bias', 'fea.0.12.weight', 'fea.0.12.bias', 'fea.0.14.weight', 'fea.0.14.bias', 'fea.0.17.weight', 'fea.0.17.bias', 'fea.0.19.weight', 'fea.0.19.bias', 'fea.0.21.weight', 'fea.0.21.bias', 'fea.0.24.weight', 'fea.0.24.bias', 'fea.0.26.weight', 'fea.0.26.bias', 'fea.0.28.weight', 'fea.0.28.bias', 'cls.0.weight', 'cls.0.bias', 'cls.3.weight', 'cls.3.bias', 'cls.6.weight', 'cls.6.bias']

根据之前的博客Pytorch学习(2) —— 网络工具箱 TORCH.NN 基本类用法,我们使用load_state_dict进行模型加载

比如用下面的方法可以将另一个模型的参数转到自己的参数上,记住strict一定要设置为false,否则会出错。

vgg.load_state_dict({'fea.0.0.weight':model['features.0.weight']}, strict=False)

总结

本部分介绍了如何将预训练的模型参数加载到自己的模型上,有时候我们的网络参数是由两个其他网络构成,那么本部分提供了一种加载方法。

至此,模型的加载用法已经完成,下面就开始介绍如何构建模型。

  • 6
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch是一个功能强大的机器学习框架。它使用动态计算图和高效的自动微分来加速深度学习。在实际编码的过程中,我们经常会使用预训练模型来加速模型训练和进一步提升模型准确率,不过一些时候我们并不需要整个预训练模型的所有参数来进行训练,而是只需要加载预训练模型部分参数。那么在PyTorch中,我们要如何来加载预训练模型部分参数呢? 要想加载预训练模型部分参数,在PyTorch中,我们可以使用load_state_dict()函数实现。load_state_dict()函数在PyTorch中是将参数拷贝到新模型中的函数,新模型和预训练模型的网络结构应该是相同的。然后我们可以通过load_state_dict()函数的参数prefix和exclude来实现部分参数加载。prefix参数是指定了预训练模型中需要加载参数的前缀,而exclude参数是指定了我们不需要加载参数。 例如,我们有一个预训练模型‘resnet18.pth’,它包含了resnet18模型在imagenet上训练好的模型参数。我们想要使用这个模型来进行一些迁移学习,那只需要加载resnet18最后一层fc层之前的所有模型参数,而不需要加载最后一层fc层的权重。那么,我们可以通过以下代码来实现: ``` import torch.utils.model_zoo as model_zoo import torchvision.models as models # 定义一个resnet18模型 resnet18 = models.resnet18(pretrained=False) # 加载预训练模型的所有参数 model_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' resnet18.load_state_dict(model_zoo.load_url(model_url)) # 获取所有要加载参数的名字 params_to_update = [] for name, param in resnet18.named_parameters(): if 'fc' not in name: params_to_update.append(name) # 加载部分预训练模型参数 state_dict = model_zoo.load_url(model_url) model_dict = resnet18.state_dict() for name, value in state_dict.items(): if name.startswith(tuple(params_to_update)): model_dict.update({name: value}) resnet18.load_state_dict(model_dict) ``` 上述代码先是定义了一个resnet18模型,然后加载resnet18预训练模型的所有参数。通过获取所有需要加载参数的名字,然后将其加载到新模型中,从而实现了加载预训练模型部分参数的目的。 总结: 通过使用load_state_dict()函数的prefix和exclude参数,在PyTorch中实现了对预训练模型部分参数加载。这将使我们在使用预训练模型时更加灵活和高效。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值