pytorch第07天 迁移学习实验

1 pytorch中迁移学习概述

所谓迁移学习,就是将已经预训练好的模型或者模型参数,导入当前的程序程序中。pytorch中实现迁移学习,可以有两种方法:
1 直接导入整个模型;
2 先创建一个模型对象,然后再导入模型参数(即模型的状态字典)。
我们今天就做几个实验对比一下两者之间的区别

2 建立一个预训练的模型文件与参数文件

这里我们建立一个神经网络模型,并实例化,然后将模型文件和参数文件保存

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

"""建立模型"""
# 搭建神经网络(定义类)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5) # 3表示输入数据的通道, 6 表示输出的通道, 5表示卷积核的宽度
        self.pool = nn.MaxPool2d(2, 2)  # 池化窗口使用(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 定义全连接层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # 一口气完成卷积、激活、池化
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 实例化神经网络模型
model = Net()

# 保存模型和模型参数g
torch.save(model, './model.pth')
torch.save(model.state_dict(), './model_weights.pth')

# 我们仅仅测试迁移学习,这里就不对模型参数进行训练

程序运行后,目录结构中多出两个文件,分别代表模型和参数(权重)
在这里插入图片描述

2 实验

我们实验从其他文件中导入刚刚保存的模型文件和参数文件,下文所称的组件,指的是 卷积层、全连接层、池化层 等。

实验1:不写类,能否直接导入模型文件和参数文件

# coding=utf-8
import torch

try:
    model_weights = torch.load('./model_weights.pth')
    # 'model_weights.pth'是另一个文件中 Net 类对象的参数文件
    print('可以在定义类之前导入参数文件')
except:
    print('不能在定义类之前导入参数文件')

try:
    model = torch.load('model.pth')
    # 'model.pth'是用另一个文件中 Net 类对象保存的模型文件
    print('可以在定义类之前导入模型文件')
    print(type(model))
except:
    print('不能在定义类之前导入模型文件')

输出

可以在定义类之前导入参数文件
不能在定义类之前导入模型文件

实验2:类名不一致,但组件名称一致,能否导入参数文件

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义NeuralNetwork类
# NeuralNetwork 与 Net 的结构和各个组件名完全一样
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5) # 3表示输入数据的通道, 6 表示输出的通道, 5表示卷积核的宽度
        self.pool = nn.MaxPool2d(2, 2)  # 池化窗口使用(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 定义全连接层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # 一口气完成卷积、激活、池化
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 神经网络实例化
model2 = NeuralNetwork()

try:
    # 导入权重文件
    model2.load_state_dict(torch.load('model_weights.pth'))
    print("即便类名不一致,但只要组件名称一致,就能顺利导入参数文件")
except:
    print("类名不一致,不能导入参数文件")

输出

即便类名不一致,但只要组件名称一致,就能顺利导入参数文件

实验3:类名不一致,不带参数的组件名称不一致,但带参数的组件一致,能否导入参数文件

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义NeuralNetwork类
# NeuralNetwork2 与 Net 的结构一致,带参数的组件也一致,但不带参数的组件名不一致
class NeuralNetwork2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool2 = nn.MaxPool2d(2, 2)  # 在Net中,池化层的名称为 self.pool
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool2(F.relu(self.conv1(x))) # 这里也要做相应的更改
        x = self.pool2(F.relu(self.conv2(x))) # 这里也要做相应的更改
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 神经网络实例化,并打印模型结构
model3 = NeuralNetwork2()


try:
    # 导入权重文件
    model3.load_state_dict(torch.load('model_weights.pth'))
    print("即便类名不一致,不带参数的组件名称也不一致,"
          "但只要带参数的组件名称一致,就能顺利导入参数文件")
except:
    print("只要有组件名称不一致,就不能导入参数文件")

输出

即便类名不一致,不带参数的组件名称也不一致,但只要带参数的组件名称一致,就能顺利导入参数文件

实验4:带参数的组件名称不一致,能否导入参数文件

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义类
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool2 = nn.MaxPool2d(2, 2)  # 在原始Net中,池化层的名称为 self.pool
        self.final_conv = nn.Conv2d(6, 16, 5)   # 在原始Net中,这一层为 self.conv2
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.final_fc = nn.Linear(84, 10)       # 在原始Net中,这一层为 self.fc3

    def forward(self, x):
        x = self.pool2(F.relu(self.conv1(x))) # 这里也要做相应的更改
        x = self.pool2(F.relu(self.conv2(x))) # 这里也要做相应的更改
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.final_fc(x)            # 这里也要做相应的更改
        return x


model4 = Net()
try:
    # 导入权重文件
    model4.load_state_dict(torch.load('model_weights.pth'))
    print("带参数的组件名称不一致,也能顺利导入参数文件")
except:
    print("带参数的组件名称不一致,不能导入参数文件")

输出

带参数的组件名称不一致,不能导入参数文件

实验5:定义类之后,类名不一致,但各组件名称一致,能否导入模型文件

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义类
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5) # 3表示输入数据的通道, 6 表示输出的通道, 5表示卷积核的宽度
        self.pool = nn.MaxPool2d(2, 2)  # 池化窗口使用(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 定义全连接层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # 一口气完成卷积、激活、池化
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


try:
    model5 = torch.load('model.pth')
    print(type(model5))
    print("定义相关类之后,如果类的结构一致,但类名不一致,也能导入模型文件")
except:
    print("定义类之后,如果类名不一致,就不能导入模型文件")

输出:

定义类之后,如果类名不一致,就不能导入模型文件

实验6:定义类之后,类名一致,但组件名称不一致,能否导入模型文件

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义类
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool2 = nn.MaxPool2d(2, 2)  # 在原始Net中,池化层的名称为 self.pool
        self.final_conv = nn.Conv2d(6, 16, 5)   # 在原始Net中,这一层为 self.conv2
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.final_fc = nn.Linear(84, 10)       # 在原始Net中,这一层为 self.fc3

    def forward(self, x):
        x = self.pool2(F.relu(self.conv1(x))) # 这里也要做相应的更改
        x = self.pool2(F.relu(self.conv2(x))) # 这里也要做相应的更改
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.final_fc(x)            # 这里也要做相应的更改
        return x

try:
    model6 = torch.load('model.pth')
    print(type(model6))
    print("定义相关类之后,类名一致,但组件名称不一致,也能导入模型文件")
except:
    print("定义类之后,即便类名一致,但组件名称不一致,就不能导入模型文件")

输出

<class '__main__.Net'>
定义相关类之后,类名一致,但组件名称不一致,也能导入模型文件

总结

综合上述实验,可以得到下面两个结论:
(1)导入参数文件,和类名无关,只与类的组件名有关,组件名必须一致;
(2)导入模型文件,和类的组件名无关,只与类名有关,类名必须一致。
当然,上述两条结论成立的前提是结构必须一致。(所谓结构一致,假如 model.pth 是Net类的对象保存的模型文件,参数文件为 model_weights.pth;若要导入将模型文件导入,则本地也要写一个Net类,且结构一致;若要将参数文件导入到本地模型,则本地模型类也要和Net结构一致)

3 pytorch中迁移学习的底层机制

(1)状态字典中的键值对

pytorch官方文档中,推荐使用导入参数文件的方式来导入模型(具体原因我也还没搞明白,可能由于是参数文件的存储空间比较小),这种方式的底层机制是导入状态字典。PyTorch模型将学习到(或者要学习)的参数存储在一个内部状态字典中,称为state_dict,所谓的保存参数,实际上就是保存这个模型的状态字典。
状态字典的类型为collections.OrderedDict类型,即有序字典,关于有序字典与普通字典区别,可以看这篇文章:https://www.cnblogs.com/lowmanisbusy/p/10257360.html

模型中的状态字典,其键为模型中各个带参数的组件名,值为这个组件的参数。

# coding=utf-8
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

# 搭建神经网络(定义类)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 4)
        self.ccrbm = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, 2),
        )

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.ccrbm(x)
        x = self.bn(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = F.batch_norm(x)
        x = self.relu(x)
        return x

# 实例化神经网络模型
model = Net()
state_dict = model.state_dict()
for k, v in state_dict.items():
    print(k, '\t', v.size())

输出

conv.weight 	 torch.Size([3, 3, 4, 4])
conv.bias 	 torch.Size([3])
ccrbm.0.weight 	 torch.Size([6, 3, 5, 5])
ccrbm.0.bias 	 torch.Size([6])
ccrbm.1.weight 	 torch.Size([16, 6, 5, 5])
ccrbm.1.bias 	 torch.Size([16])
ccrbm.3.weight 	 torch.Size([16])
ccrbm.3.bias 	 torch.Size([16])
ccrbm.3.running_mean 	 torch.Size([16])
ccrbm.3.running_var 	 torch.Size([16])
ccrbm.3.num_batches_tracked 	 torch.Size([])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([10, 120])
fc2.bias 	 torch.Size([10])

ccrbm是nn.Sequential类的对象,容器内部封装了5层,ccrbm后面的数字,表示在ccrbm内部属于第几层。

我们在 Net 类中定义了5个组件,分别为 conv,ccrbm,fc1,fc2,relu,但从输出中可以看到,状态字典中只存储了 conv,ccrbm,fc1,fc2 这三个组件的参数,并且ccrbm内的5层中,只有第0,1,3层有参数,所以状态字典中只存储这三层。

另外,在 forward() 方法中,也有BN层,但状态字典中却没收录对应的参数,这是因为这个BN层不是 Net 类的成员变量,也就是说,它不是Net类的组件,pytorch 只把模块组件中的参数当成要更新的参数。

(2)导入状态字典到当前模型

我们现在知道了模型状态字典的键值对是什么了,就能理解为什么“导入参数文件,和类名无关,只与类的组件名有关,组件名必须一致”这句话了

new_model.load_state_dict(torch.load('model_weights.pth'))

上面这条语句,是把参数文件(即状态字典)导入到新模型中,由于.load_state_dict方法不能直接导入文件,所以要先用torch.load把权重文件导入到内存中。

4 迁移学习的两种场景

上面我们是从工具(pytorch)角度讲了迁移学习的两种实现方法,现在我们来讲一下迁移学习的两种用法:

(1)重新训练全连接层

将卷积层(包括BN层)的参数固定,作为特征提取器,然后重新定义并训练全连接层。这种用法往往是因为自己的网络和别人的网络输出长度不一样,比如ResNet18的输出长度是10,可以用于10分类,但我自己的任务是二分类,那么就必须重新定义全连接层,让其出去长度为2,然后冻结卷积层的参数,只训练全连接层。

修改全连接层的代码如下:

import torchvision

model_conv = torchvision.models.resnet18(pretrained=True)	# 导入模型,这里torchvision中集成了
for param in model_conv.parameters():					# 冻结各层的参数
    param.requires_grad = False

num_ftrs = model_conv.fc.in_features		# 拿到全连接层的输入特征数
model_conv.fc = nn.Linear(num_ftrs, 2)		# 重新定义全连接层

接下来的过程与前面一样,定义优化器、损失函数,训练等。

如何获取某一层的输入特征数,或者输入通道数、卷积核大小这些信息?
以上面定义的model_conv为例,我们来获取指定层的卷积核大小
先打印model_conv的结构

print(model_conv)

当然,结构太长,我们只截取一段图片
在这里插入图片描述
假设我们要拿到截图中红色方框中的信息,即 layer4——0——conv1 的卷积和大小,可以使用下面的语句

print(model_conv.layer4[0].conv1.kernel_size)		# 按照蓝色方框一个一个地索引,遇见纯数字则用方括号

输出

(3, 3)

(2)微调整个网络

这是将预训练模型的参数作为当前模型的初始化参数,然后训练整个模型,上面的方法只是训练全连接层,这个方法是训练整个网络,因此会慢很多。这种方法一般是用在预训练模型的特征提取能力不够理想,或者预训练网络对当前任务的特征提取能力不足。当然,既然是微调,迭代次数毕竟不会太多。
关于迁移学习的两种用法,可以看这个教程:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中,实现迁移学习的方法有两种。一种是微调网络的方法,即更改最后一层全连接,并且微调训练网络。另一种是将模型看作特征提取器,冻结所有层并且更改最后一层,只训练最后一层。这样可以快速训练模型而准确率不低于自己训练的模型。 在实施迁移学习之前,我们需要准备数据并选择合适的模型。数据的准备包括选择数据增广的方式,而模型的选择可以使用PyTorch提供的预训练模型,如VGG16等。 在使用PyTorch进行迁移学习时,我们可以使用torchvision.models中的预训练模型。例如,可以使用models.vgg16(pretrained=True)来加载在ImageNet数据集上预训练的VGG16模型。然后,我们可以通过设置每个参数的requires_grad属性为False来冻结所有层,使其参数不会更新。 以上是关于在PyTorch中实现迁移学习的基本步骤和方法。具体的实现细节可以根据具体的需求和问题进行调整和修改。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [【Pytorch迁移学习(Transfer Learning)](https://blog.csdn.net/m0_51941269/article/details/128258212)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [PyTorch使用教程-迁移学习(几分钟即可训练好自己的模型)](https://blog.csdn.net/weixin_42263486/article/details/108302350)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值