基于torchvision对模型最后几层进行微调,用于训练自己的数据

简介

对模型进行微调训练是常见的模型训练策略,首先我们要查看自己的torchvision版本。

为什么要查看torchvision版本呢?因为不同版本的torchvisionmodels模块中含有不同的网络架构。

例如0.11.0torchvision.models中提供了以下模型:

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
regnet_y_32gf = models.regnet_y_32gf(pretrained=True)
regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
regnet_x_16gf = models.regnet_x_16gf(pretrainedTrue)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)

使用

接下来我们进行基于torchvision模型最后几层进行微调,用于训练自己的数据。

导入库

import torch.nn as nn
from torchvision import models

加载模型并查看

model = models.resnet50(pretrained=True, progress=False)
print(model)

其打印的模型结构非常长,我们只需要看最后一个部分即可:

(2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

在这里我们需要修改fc层的out_features为你需要指定的类别数。

如猫狗分类任务中,out_features就应该等于2.

修改模型的out_features

num_ftrs = model.fc.in_features
class_num = 5 # 你的类别数量,比如猫狗识别,就是2
model.fc = nn.Linear(num_ftrs, class_num)
print(model)

最后我们打印出来修改的网络结构:

(2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=5, bias=True)
)

可以看到已经是进行修改了,接下来就可以进行训练了。

其他模型的修改实战

接下来,我们加载efficientnet_b0对其网络结构进行修改:

model = models.efficientnet_b0(pretrained=True, progress=False)
print(model)
'''
(8): ConvNormActivation(
      (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=1)
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=True)
    (1): Linear(in_features=1280, out_features=1000, bias=True)
  )
)
'''
num_ftrs = model.classifier[1].in_features
class_num = 5 # 你的类别数量,比如猫狗识别,就是2
model.classifier[1] = nn.Linear(num_ftrs, class_num)
print(model)
'''
(8): ConvNormActivation(
      (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=1)
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=True)
    (1): Linear(in_features=1280, out_features=5, bias=True)
  )
  (fc): Linear(in_features=1280, out_features=5, bias=True)
)
'''
  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
PyTorch EVA02模型微调是指在已经训练好的EVA02模型基础上,通过对新的数据进行训练,以适应特定任务或数据集的需求。微调可以帮助我们利用预训练模型的知识和参数,加速模型训练过程,并提高模型在新任务上的性能。 下面是PyTorch EVA02模型微调的一般步骤: 1. 加载预训练模型:首先,你需要下载并加载EVA02模型的预训练权重。PyTorch提供了方便的接口来加载预训练模型,例如使用`torchvision.models`中的`resnet`模块。 2. 修改模型结构:根据你的任务需求,你可能需要修改EVA02模型最后几层或全连接层。通常情况下,你需要将最后一层的输出节点数修改为你任务中的类别数。 3. 冻结部分参数:为了保留预训练模型的知识,你可以选择冻结部分参数,即不对它们进行更新。一般来说,冻结预训练模型的前几层或者全部卷积层是常见的做法。 4. 定义损失函数和优化器:根据你的任务类型,选择适当的损失函数和优化器。常见的损失函数包括交叉熵损失函数、均方误差损失函数等,常见的优化器包括随机梯度下降(SGD)、Adam等。 5. 训练模型:使用新的数据集对模型进行训练。你可以通过迭代数据集的方式,将数据输入模型,计算损失并进行反向传播更新模型参数。 6. 评估模型性能:在训练过程中,你可以使用验证集来评估模型在新任务上的性能。常见的评估指标包括准确率、精确率、召回率等。 7. 微调参数:如果模型在新任务上的性能不理想,你可以微调部分参数,即解冻之前冻结的层,并继续训练模型

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

落难Coder

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值