第九章 预训练模型与自己模型参数不匹配和模型微调的具体实现(工具)

117 篇文章 2 订阅
58 篇文章 0 订阅

导入预训练模型在通常情况下都能加快模型收敛,提升模型性能。但根据实际任务需求,自己搭建的模型往往和通用的Backbone并不能做到网络层的完全一致,无非就是少一些层和多一些层两种情况。

1. 自己模型层数较少

net = ...   # net为自己的模型
save_model = torch.load('path_of_pretrained_model') # 获取预训练模型字典(键值对)
model_dict = net.state_dict() # 获取自己模型字典(键值对)
# 新定义字典,用来获取自己模型中对应层的预训练模型中的参数
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()} 
model_dict.update(state_dict) # 更新自己模型字典中键值对
net.load_state_dict(model_dict) # 加载参数

其中update:对于state_dict和model_dict都有的键值对,前者对应的值会替换后者,若前者有后者没有的键值对,则会添加这些键值对到后者字典中。

2. 自己模型层数较多

model_dict = model.state_dict() # 自己模型字典
model_dict.update(pretrained_model) # 直接将预训练模型参数更新进来
model.load_state_dict(model_dict) # 加载

结果与预训练模型对应的层权重被加载了,其它层则为默认初始化。

由于模型参数都为字典形式存在,可以用字典的增删方式进行更灵活的操作

特征提取与微调

迁移学习是一种有效的机器学习方法,尤其是在数据不足的情况下。它通过使用在大型数据集(如 ImageNet)上预训练的模型,将这些模型的知识应用于新的、相对较小的数据集上。下面,我将详细解释两种主要的迁移学习策略,并提供相应的代码示例。

1. 特征提取(Feature Extraction)

在这种策略中,你使用预训练模型的卷积层(即模型的前几层)来作为固定的特征提取器。然后,你只需添加一些新的可训练层(通常是全连接层),以便根据新的数据集进行预测。

代码示例

假设我们要在一个新的数据集上使用预训练的 ResNet 模型进行特征提取:

import torch
import torch.nn as nn
from torchvision import models

# 加载预训练的 ResNet 模型
model = models.resnet18(pretrained=True)
# 冻结所有卷积层的参数
for param in model.parameters():
    param.requires_grad = False

# 替换最后的全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2) # 假设我们的新数据集有2个类别
# 现在,只有 model.fc 层的参数会在训练中更新

在这个示例中,我们首先加载了预训练的 ResNet 模型,并冻结了它的所有卷积层。然后,我们替换了最后的全连接层以适应新的分类任务(假设有两个类别)。在训练过程中,只有这个新添加的全连接层的参数会更新。

2. 微调(Fine-Tuning)

微调涉及到在预训练模型的基础上进一步训练整个模型(或模型的大部分层)。这通常是在特征提取的基础上进行的,即首先使用特征提取策略训练模型,然后解冻整个模型的多个层,并对整个模型进行进一步的训练。

代码示例

继续使用上面的 ResNet 示例,我们现在将进行微调:

# 之前的代码省略...
# 解冻所有层
for param in model.parameters():
    param.requires_grad = True
# 现在整个模型的参数都会在训练中更新

在微调阶段,所有层的参数都会被更新。这通常是在特征提取阶段之后进行,特别是当新数据集与原始数据集在特征上有较大差异时。

添加、删除或修改字典中的项

在 PyTorch 中,模型的参数是以字典形式存储的,这使得我们可以利用 Python 字典的特性来灵活地处理模型参数。例如,我们可以添加、删除或修改字典中的项,以此来自定义模型的参数。以下是一些常见操作的示例:

修改

示例:修改特定层的参数

假设您想要修改预训练模型中某一层的参数,比如将第一层卷积层的权重全部设置为零。

import torch
import torchvision.models as models

# 加载预训练模型
model = models.resnet18(pretrained=True)
# 将第一层卷积层的权重设置为零
torch.nn.init.constant_(model.conv1.weight, 0)

在这个示例中,我们使用 torch.nn.init.constant_ 方法直接修改了模型的第一层卷积层(conv1)的权重,将其全部设置为零。

示例:修改层的属性

除了修改层的参数外,您还可以修改层的其他属性。例如,您可能想要更改卷积层的步长(stride)或填充(padding)。

# 改变第一层卷积层的步长和填充
model.conv1.stride = (2, 2)
model.conv1.padding = (1, 1)

在这个示例中,我们修改了模型的第一层卷积层的步长和填充属性。

注意事项

  • 在修改模型的参数或属性时,请确保您的修改不会导致模型架构的不一致。例如,更改卷积层的核大小或步长可能会影响模型中后续层的输入尺寸。
  • 修改模型参数通常需要对模型的工作原理有深入的理解,以避免意外地破坏模型的性能。

通过这些修改,您可以对模型进行微调,使其更好地适应特定的任务或数据集。这种能力在进行模型实验和优化时非常有价值。

示例 1:删除特定层的参数

假设你想要从预训练模型中删除某些层的参数。这可以通过删除字典中相应的键值对来实现。

import torch
import torchvision.models as models

# 加载预训练模型
model = models.resnet18(pretrained=True)

# 获取模型的 state_dict
model_dict = model.state_dict()

# 假设我们想要删除第一层卷积层的参数
del model_dict['conv1.weight']
del model_dict['conv1.bias']

# 更新模型的 state_dict
model.load_state_dict(model_dict, strict=False)

在这个示例中,我们首先加载了预训练的 ResNet18 模型,并获取其状态字典。然后,我们删除了第一层卷积层(conv1)的权重和偏置参数。最后,我们使用 load_state_dict 更新模型的参数,strict=False 表示我们允许不匹配的项存在。

示例 2:添加新层的参数

如果你想向模型中添加新的层,并为其初始化参数,可以直接向状态字典中添加新的键值对。

import torch.nn as nn

# 继续之前的模型
# 添加一个新的全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
               nn.Linear(num_ftrs, 500),
               nn.ReLU(),
               nn.Linear(500, 2)
           )

# 初始化新层的参数
nn.init.xavier_normal_(model.fc[0].weight)
nn.init.constant_(model.fc[0].bias, 0)
nn.init.xavier_normal_(model.fc[2].weight)
nn.init.constant_(model.fc[2].bias, 0)

# 获取新的模型 state_dict
new_model_dict = model.state_dict()

# 可以选择性地将新的 state_dict 保存下来
# torch.save(new_model_dict, 'modified_model.pth')
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小酒馆燃着灯

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值