Pytorch-day09-模型微调-checkpoint

模型微调(fine-tune)-迁移学习

  • torchvision微调
  • timm微调
  • 半精度训练

起源:

  • 1、随着深度学习的发展,模型的参数越来越大,许多开源模型都是在较大数据集上进行训练的,比如Imagenet-1k,Imagenet-11k等
  • 2、如果数据集可能只有几千张,训练几千万参数的大模型,过拟合无法避免
  • 3、如果我们想从零开始训练一个大模型,那么我们的解决办法是收集更多的数据。然而,收集和标注数据会花费大量的时间和资⾦,成本无法承受

解决方案:

  • 应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上
  • 比如:ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成
  • 模型微调(finetune):就是先找到一个同类的别人训练好的模型,基于已经训练好的模型换成自己的数据,通过训练调整一下参数
不同数据集下使用微调:
  • 数据集1 - 数据量少,但数据相似度非常高 - 在这种情况下,我们所做的只是修改最后几层或最终的softmax图层的输出类别。

  • 数据集2 - 数据量少,数据相似度低 - 在这种情况下,我们可以冻结预训练模型的初始层(比如k层),并再次训练剩余的(n-k)层。由于新数据集的相似度较低,因此根据新数据集对较高层进行重新训练具有重要意义。

  • 数据集3 - 数据量大,数据相似度低 - 在这种情况下,由于我们有一个大的数据集,我们的神经网络训练将会很有效。但是,由于我们的数据与用于训练我们的预训练模型的数据相比有很大不同。使用预训练模型进行的预测不会有效。因此,最好根据你的数据从头开始训练神经网络(Training from scatch)

  • 数据集4 - 数据量大,数据相似度高 - 这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构和模型的初始权重。然后,我们可以使用在预先训练的模型中的权重来重新训练该模型。

微调的是什么?
  • 换数据源
  • 针对K层进行重新训练
  • K层的权重&shape调整

1、模型微调(fine-tune)一般流程:

  • 1、在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型
  • 2、创建一个新的神经网络模型,即目标模型,它复制了源模型上除了输出层外的所有模型设计及其参数
  • 3、为目标模型添加一个输出⼤小为⽬标数据集类别个数的输出层,并随机初始化该层的模型参数
  • 4、在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-x43rJPAE-1692785269018)(attachment:image.png)]

2、torchvision微调

2.1 实例化Model

import torchvision.models as models
resnet34 = models.resnet34(pretrained=True)

pretrained参数说明:

  • 1、通过True或者False来决定是否使用预训练好的权重,在默认状态下pretrained = False,意味着我们不使用预训练得到的权重
  • 2、当pretrained = True,意味着我们将使用在一些数据集上预训练得到的权重

注意:如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,否则会报错。

2.2 训练特定层

如果我们正在提取特征并且只想为新初始化的层计算梯度,其他参数不进行改变。那我们就需要通过设置requires_grad = False来冻结部分层

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

2.3 实例

  • 使用resnet34为例的将1000类改为10类,但是仅改变最后一层的模型参数
  • 我们先冻结模型参数的梯度,再对模型输出部分的全连接层进行修改
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import StepLR
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.models as models
from torchinfo import summary
#超参数定义
# 批次的大小
batch_size = 16 #可选32、64、128
# 优化器的学习率
lr = 1e-4
#运行epoch
max_epochs = 2
# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 
# 数据读取
#cifar10数据集为例给出构建Dataset类的方式
from torchvision import datasets

#“data_transform”可以对图像进行一定的变换,如翻转、裁剪、归一化等操作,可自己定义
data_transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                   ])


train_cifar_dataset = datasets.CIFAR10('cifar10',train=True, download=False,transform=data_transform)
test_cifar_dataset = datasets.CIFAR10('cifar10',train=False, download=False,transform=data_transform)

#构建好Dataset后,就可以使用DataLoader来按批次读入数据了
train_loader = torch.utils.data.DataLoader(train_cifar_dataset, 
                                           batch_size=batch_size, num_workers=4, 
                                           shuffle=True, drop_last=True)

test_loader = torch.utils.data.DataLoader(test_cifar_dataset, 
                                         batch_size=batch_size, num_workers=4, 
                                         shuffle=False)


# 下载预训练模型 restnet50
resnet34 = models.resnet34(pretrained=True)
print(resnet34)
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to C:\Users\xulele/.cache\torch\hub\checkpoints\resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:10<00:00, 8.57MB/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
#查看模型结构
summary(resnet34, (1, 3, 224, 224)) 
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
│    └─BasicBlock: 2-3                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-14            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-15                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-16                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-17            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-18                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-19                 [1, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-20            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-21                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-22                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-23            [1, 128, 28, 28]          256
│    │    └─Sequential: 3-24             [1, 128, 28, 28]          8,448
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-5                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-26                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-27            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-28                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-29                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-30            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-31                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-6                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-32                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-33            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-34                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-35                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-36            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-37                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-7                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-38                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-39            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-40                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-41                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-42            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-43                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-8                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-44                 [1, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-45            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-46                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-47                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-48            [1, 256, 14, 14]          512
│    │    └─Sequential: 3-49             [1, 256, 14, 14]          33,280
│    │    └─ReLU: 3-50                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-9                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-51                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-52            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-53                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-54                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-55            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-56                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-10                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-57                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-58            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-59                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-60                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-61            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-62                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-11                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-63                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-64            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-65                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-66                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-67            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-68                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-12                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-69                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-70            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-71                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-72                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-73            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-74                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-13                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-75                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-76            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-77                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-78                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-79            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-80                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-14                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-81                 [1, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-82            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-83                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-84                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-85            [1, 512, 7, 7]            1,024
│    │    └─Sequential: 3-86             [1, 512, 7, 7]            132,096
│    │    └─ReLU: 3-87                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-15                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-88                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-89            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-90                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-91                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-92            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-93                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-16                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-94                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-95            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-96                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-97                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-98            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-99                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
==========================================================================================
Total params: 21,797,672
Trainable params: 21,797,672
Non-trainable params: 0
Total mult-adds (G): 3.66
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 59.82
Params size (MB): 87.19
Estimated Total Size (MB): 147.61
==========================================================================================
#检测 模型准确率
def cal_predict_correct(model):
    test_total_correct = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)
    
        outputs = model(images)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
#     print("test_total_correct: "+ str(test_total_correct))
    return test_total_correct
total_correct = cal_predict_correct(resnet34)
print("test_total_correct: "+ str(test_total_correct / 10000))
test_total_correct: 0.1
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
            

# 冻结参数的梯度
feature_extract = True
new_model = resnet34
set_parameter_requires_grad(new_model, feature_extract)

# 修改模型
#训练过程中,model仍会进行梯度回传,但是参数更新则只会发生在fc层
num_ftrs = new_model.fc.in_features
new_model.fc = nn.Linear(in_features=num_ftrs, out_features=10, bias=True)


summary(new_model, (1, 3, 224, 224)) 
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 10]                   --
├─Conv2d: 1-1                            [1, 64, 112, 112]         (9,408)
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         (128)
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
│    └─BasicBlock: 2-3                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-14            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-15                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-16                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-17            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-18                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-19                 [1, 128, 28, 28]          (73,728)
│    │    └─BatchNorm2d: 3-20            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-21                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-22                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-23            [1, 128, 28, 28]          (256)
│    │    └─Sequential: 3-24             [1, 128, 28, 28]          (8,448)
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-5                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-26                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-27            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-28                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-29                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-30            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-31                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-6                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-32                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-33            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-34                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-35                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-36            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-37                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-7                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-38                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-39            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-40                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-41                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-42            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-43                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-8                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-44                 [1, 256, 14, 14]          (294,912)
│    │    └─BatchNorm2d: 3-45            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-46                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-47                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-48            [1, 256, 14, 14]          (512)
│    │    └─Sequential: 3-49             [1, 256, 14, 14]          (33,280)
│    │    └─ReLU: 3-50                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-9                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-51                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-52            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-53                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-54                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-55            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-56                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-10                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-57                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-58            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-59                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-60                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-61            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-62                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-11                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-63                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-64            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-65                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-66                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-67            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-68                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-12                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-69                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-70            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-71                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-72                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-73            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-74                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-13                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-75                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-76            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-77                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-78                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-79            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-80                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-14                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-81                 [1, 512, 7, 7]            (1,179,648)
│    │    └─BatchNorm2d: 3-82            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-83                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-84                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-85            [1, 512, 7, 7]            (1,024)
│    │    └─Sequential: 3-86             [1, 512, 7, 7]            (132,096)
│    │    └─ReLU: 3-87                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-15                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-88                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-89            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-90                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-91                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-92            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-93                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-16                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-94                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-95            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-96                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-97                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-98            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-99                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 10]                   5,130
==========================================================================================
Total params: 21,289,802
Trainable params: 5,130
Non-trainable params: 21,284,672
Total mult-adds (G): 3.66
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 59.81
Params size (MB): 85.16
Estimated Total Size (MB): 145.57
==========================================================================================
#训练&验证
Resnet34_new = new_model.to(device)
# 定义损失函数和优化器
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 损失函数:自定义损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(Resnet50_new.parameters(), lr=lr)
epoch = max_epochs

total_step = len(train_loader)
train_all_loss = []
test_all_loss = []

for i in range(epoch):
    Resnet34_new.train()
    train_total_loss = 0
    train_total_num = 0
    train_total_correct = 0

    for iter, (images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = Resnet34_new(images)
        loss = criterion(outputs,labels)
        train_total_correct += (outputs.argmax(1) == labels).sum().item()
        
        #backword
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))
    
    Resnet34_new.eval()
    test_total_loss = 0
    test_total_correct = 0
    test_total_num = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = Resnet34_new(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
    print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
        i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100
    
    ))
    train_all_loss.append(np.round(train_total_loss / train_total_num,4))
    test_all_loss.append(np.round(test_total_loss / test_total_num,4))

Epoch [1/2], Iter [1481/3125], train_loss:0.17220

半精度训练

问题:

GPU的性能主要分为两部分:算力和显存。
前者决定了显卡计算的速度,后者则决定了显卡可以同时放入多少数据用于计算
在可以使用的显存数量一定的情况下,每次训练能够加载的数据更多(也就是batch size更大),则也可以提高训练效率
定义:

PyTorch默认的浮点数存储方式用的是torch.float32,小数点后位数更多固然能保证数据的精确性
但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为“半精度”
image.png

显然半精度能够减少显存占用,使得显卡可以同时加载更多数据进行计算

3.1、半精度训练的设置
1、引入 from torch.cuda.amp import autocast
2、forward函数指定 autocast 装饰器
3、训练过程: 只需在将数据输入模型及其之后的部分放入“with autocast():“
4、半精度训练主要适用于数据本身的size比较大(比如说3D图像、视频等)

引入


from torch.cuda.amp import autocast

# forward指定装饰器
@autocast()   
def forward(self, x):
    ...
    return x

# 指定with autocast 
 for x in train_loader:
    x = x.cuda()
    with autocast():
            output = model(x)
        ...

半精度训练案例
from torch.cuda.amp import autocast

半精度模型

class DemoModel(nn.Module):
def init(self):
super(DemoModel, self).init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

@autocast() 
def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x





#训练&验证

device = torch.device(‘cuda:0’ if torch.cuda.is_available() else ‘cpu’)
half_model = DemoModel().to(device)

损失函数:自定义损失函数

criterion = nn.CrossEntropyLoss()

优化器

optimizer = torch.optim.Adam(Resnet50_new.parameters(), lr=lr)
epoch = max_epochs

total_step = len(train_loader)
train_all_loss = []
test_all_loss = []

for i in range(epoch):
half_model.train()
train_total_loss = 0
train_total_num = 0
train_total_correct = 0

for iter, (images,labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
with autocast():
outputs = half_model(images)
loss = criterion(outputs,labels)
train_total_correct += (outputs.argmax(1) == labels).sum().item()

#backword
optimizer.zero_grad()
loss.backward()
optimizer.step()

        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))

half_model.eval()
test_total_loss = 0
test_total_correct = 0
test_total_num = 0
for iter,(images,labels) in enumerate(test_loader):
    images = images.to(device)
    labels = labels.to(device)
    with autocast():
        outputs = half_model(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
        print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
            i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100

))
train_all_loss.append(np.round(train_total_loss / train_total_num,4))
test_all_loss.append(np.round(test_total_loss / test_total_num,4))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
pytorch-09.ipynb是一个使用PyTorch库进行深度学习实践的笔记本文件。PyTorch是一个基于Python的深度学习框架,它提供了方便简洁的API接口,使得深度学习模型的构建和训练变得更加容易。 在这个笔记本文件中,我推测可能包括以下内容: 1. 张量的基本概念和操作:张量是PyTorch中最基本的数据类型,类似于Numpy中的多维数组。这个笔记本可能会介绍如何创建和操作张量,以及张量在深度学习中的应用。 2. 自动梯度计算:PyTorch通过自动梯度计算(Autograd)模块实现了计算图和反向传播。这个笔记本可能会介绍如何使用PyTorch的autograd模块来计算张量的导数,并利用导数进行模型参数的更新。 3. 模型构建和训练:深度学习模型的构建和训练是PyTorch的核心功能。这个笔记本可能会介绍如何使用PyTorch构建各种类型的神经网络模型(如全连接网络、卷积神经网络和循环神经网络)并进行训练。 4. 数据加载和预处理:在深度学习中,数据的加载和预处理是非常重要的一步。这个笔记本可能会介绍如何使用PyTorch的数据加载器和数据转换工具进行数据的加载和处理。 5. 模型性能评估和调优:在实际应用中,评估模型性能和进行调优是不可或缺的步骤。这个笔记本可能会介绍如何使用PyTorch进行模型性能的评估,并介绍一些常见的调优方法,如学习率调整、正则化和dropout等。 总之,这个笔记本文件可能会提供一些关于PyTorch库的基本操作和深度学习模型构建的实践指南,帮助读者更好地理解和应用PyTorch进行深度学习任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

闪闪发亮的小星星

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值