Week J2: ResNet50V2 in Practice and Analysis (PyTorch Version)

>- **🍨 This article is a learning-record blog from the [🔗365天深度学习训练营] training camp**
>- **🍖 Original author: [K同学啊]**

📌 This week's tasks 📌
– 1. Based on the TensorFlow code in the original article, write the corresponding PyTorch code (it is recommended to verify the model on last week's dataset)
– 2. Understand the differences between ResNetV2 and ResNetV1
– 3. Consider whether the improvement ideas can be transferred elsewhere (free exploration)

🏡 My environment:

  • Language: Python 3.8
  • Editor: Jupyter Notebook
  • Deep learning framework: PyTorch
    • torch==2.3.1+cu118
    • torchvision==0.18.1+cu118

This article is a direct conversion of Week J2: ResNet50V2 in Practice and Analysis (TensorFlow version) to PyTorch, so the background material is not repeated here; only the PyTorch-specific content is described.

I. Preparation

1. Set up the GPU

Use the GPU if one is available; otherwise fall back to the CPU.

import warnings
warnings.filterwarnings("ignore") #suppress warnings

import torch
device=torch.device("cuda" if torch.cuda.is_available() else "cpu") #device strings are lowercase; "CPU" would raise an error
device

Output:

device(type='cuda')

2. Load the data

import pathlib

data_dir=r"D:\THE MNIST DATABASE\J-series\J1\bird_photos"
data_dir=pathlib.Path(data_dir)

Count the total number of images in the dataset:

image_count=len(list(data_dir.glob('*/*')))
print("Total number of images:",image_count)

Output:

Total number of images: 565

3. Inspect the dataset classes

data_paths=list(data_dir.glob('*'))
classNames=[path.name for path in data_paths] #each top-level folder name is a class name (more robust than hard-coding a path index)
classNames

Output:

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

4. View random images

Randomly sample 20 images from the dataset for a quick look:

import random
import matplotlib.pyplot as plt
from PIL import Image

data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(20,4))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.axis("off")
    image=random.choice(data_paths2) #pick one image at random
    plt.title(image.parts[-2]) #the parent folder name is the class name
    plt.imshow(Image.open(str(image))) #display the image

Output: (a 2×10 grid of randomly sampled bird photos, each titled with its class name)

5. Preprocess the images

from torchvision import transforms,datasets

train_transforms=transforms.Compose([
    transforms.Resize([224,224]), #resize every image to 224x224
    transforms.RandomHorizontalFlip(), #random horizontal flip
    transforms.RandomRotation(0.2), #random rotation within ±0.2 degrees (RandomRotation takes degrees, not radians)
    transforms.ToTensor(), #convert the image to a tensor
    transforms.Normalize(  #normalize with the ImageNet channel statistics, which helps the model converge
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])

total_data=datasets.ImageFolder(
    r"D:\THE MNIST DATABASE\J-series\J1\bird_photos",
    transform=train_transforms
)
total_data

Output:

Dataset ImageFolder
    Number of datapoints: 565
    Root location: D:\THE MNIST DATABASE\J-series\J1\bird_photos
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-0.2, 0.2], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

Print the dataset's class-to-index mapping:

total_data.class_to_idx

Output:

{'Bananaquit': 0,
 'Black Skimmer': 1,
 'Black Throated Bushtiti': 2,
 'Cockatoo': 3}
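
For turning predicted indices back into class names later, this mapping can simply be inverted; a minimal sketch (the idx_to_class name is my own):

#invert the class-to-index mapping: predicted index -> class name
idx_to_class={v:k for k,v in total_data.class_to_idx.items()}
idx_to_class[1] #'Black Skimmer'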

6. Split the dataset

train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,
    [train_size,test_size]
)
train_dataset,test_dataset

Output:

(<torch.utils.data.dataset.Subset at 0x1fc1cb43e90>,
 <torch.utils.data.dataset.Subset at 0x1fc1cb43f10>)

Check the sizes of the training and test sets:

train_size,test_size

Output:

(452, 113)
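
Note that random_split draws a fresh permutation each run, so the split differs between sessions. If reproducibility matters, a fixed generator can be passed in; a sketch (the seed 42 is arbitrary):

#reproducible split: pin the RNG used by random_split
train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,
    [train_size,test_size],
    generator=torch.Generator().manual_seed(42) #arbitrary seed
)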

7. Load the datasets with DataLoader

batch_size=16
train_dl=torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)
test_dl=torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)

Check the shape of one batch from the training set:

for x,y in train_dl:
    print("Shape of x [N,C,H,W]:",x.shape)
    print("Shape of y:",y.shape,y.dtype)
    break

Output:

Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64

II. Building ResNet50V2 by hand

1. Residual Block

import torch.nn as nn

#Residual Block (pre-activation, ResNetV2 style)
class Block2(nn.Module):
    def __init__(self,in_channel,filters,kernel_size=3,stride=1,conv_shortcut=False):
        super(Block2,self).__init__()
        #pre-activation: BN and ReLU come before the convolutions
        self.preact=nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU()
        )
        
        self.shortcut=conv_shortcut
        
        if self.shortcut:
            #projection shortcut: 1x1 convolution to match the output channels
            self.short=nn.Conv2d(in_channel,4*filters,1,stride=stride,padding=0,bias=False)
        elif stride>1:
            #identity shortcut with downsampling: 1x1 max-pooling with the block's stride
            self.short=nn.MaxPool2d(kernel_size=1,stride=stride,padding=0)
        else:
            self.short=nn.Identity()
            
        self.conv1=nn.Sequential(
            nn.Conv2d(in_channel,filters,1,stride=1,bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU()
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(filters,filters,kernel_size,stride=stride,padding=1,bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU()
        )
        self.conv3=nn.Conv2d(filters,4*filters,1,stride=1,bias=False)
        
    def forward(self,x):
        x1=self.preact(x)
        #the projection shortcut consumes the pre-activated tensor,
        #while the identity/pooling shortcut consumes the raw input
        if self.shortcut:
            x2=self.short(x1)
        else:
            x2=self.short(x)
        x1=self.conv1(x1)
        x1=self.conv2(x1)
        x1=self.conv3(x1)
        x=x1+x2
        return x
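
Before stacking the blocks, it is worth sanity-checking the three shortcut variants with dummy inputs; a quick sketch (the shapes follow from the definitions above):

#shape checks for the three shortcut variants of Block2
x=torch.randn(1,64,56,56)
print(Block2(64,64,conv_shortcut=True)(x).shape)  #projection: torch.Size([1, 256, 56, 56])
print(Block2(256,64)(torch.randn(1,256,56,56)).shape)  #identity: torch.Size([1, 256, 56, 56])
print(Block2(256,64,stride=2)(torch.randn(1,256,56,56)).shape)  #downsampling: torch.Size([1, 256, 28, 28])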

2. Stacking residual blocks

class Stack2(nn.Module):
    def __init__(self,in_channel,filters,blocks,stride=2):
        super(Stack2,self).__init__()
        self.conv=nn.Sequential()
        self.conv.add_module(str(0),Block2(in_channel,filters,conv_shortcut=True))
        for i in range(1,blocks-1):
            self.conv.add_module(str(i),Block2(4*filters,filters))
        self.conv.add_module(str(blocks-1),Block2(4*filters,filters,stride=stride))
        
    def forward(self,x):
        x=self.conv(x)
        return x
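
Note that, following Keras's stack2, the stride-2 block comes last in each stage, whereas ResNetV1 downsamples in the first block of a stage. A quick check that one stage expands the channels and halves the spatial size:

#one stage: 64 -> 4*64 channels, 56x56 -> 28x28
stage=Stack2(64,64,blocks=3)
print(stage(torch.randn(1,64,56,56)).shape)  #torch.Size([1, 256, 28, 28])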

3. Reproducing the ResNet50V2 architecture

#Build ResNet50V2
class ResNet50V2(nn.Module):
    def __init__(
        self,
        include_top=True,  #whether to include the fully connected head at the top of the network
        preact=True,  #whether to use pre-activation
        use_bias=True,  #whether the stem convolution uses a bias term
        input_shape=[224,224,3],  #expected input size (kept for reference, not used below)
        classes=1000,  #optional number of classes to classify images into
        pooling=None   #optional pooling mode when include_top=False: 'avg' or 'max'
    ):
        super(ResNet50V2,self).__init__()
        
        self.conv1=nn.Sequential()
        self.conv1.add_module('conv',
                              nn.Conv2d(3,64,7,stride=2,padding=3,bias=use_bias,padding_mode='zeros'))
        if not preact:
            self.conv1.add_module('bn',nn.BatchNorm2d(64))
            self.conv1.add_module('relu',nn.ReLU())
        self.conv1.add_module('max_pool',
                              nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
        
        self.conv2=Stack2(64,64,3)
        self.conv3=Stack2(256,128,4)
        self.conv4=Stack2(512,256,6)
        self.conv5=Stack2(1024,512,3,stride=1)
        
        self.post=nn.Sequential()
        if preact:
            self.post.add_module('bn',nn.BatchNorm2d(2048))
            self.post.add_module('relu',nn.ReLU())
        if include_top:
            self.post.add_module('avg_pool',nn.AdaptiveAvgPool2d((1,1)))
            self.post.add_module('flatten',nn.Flatten())
            self.post.add_module('fc',nn.Linear(2048,classes)) #head width follows the classes argument
        else:
            if pooling=='avg':
                self.post.add_module('avg_pool',nn.AdaptiveAvgPool2d((1,1)))
            elif pooling=='max':
                self.post.add_module('max_pool',nn.AdaptiveMaxPool2d((1,1))) #max pooling, not average
                
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=self.conv4(x)
        x=self.conv5(x)
        x=self.post(x)
        return x
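
Before printing the full structure, a dummy forward pass confirms the pieces wire together and that a 224×224 RGB image yields the expected number of logits; a quick sketch:

#sanity check: one fake image in, 4 class logits out
net=ResNet50V2(classes=4).eval()
with torch.no_grad():
    print(net(torch.randn(1,3,224,224)).shape)  #torch.Size([1, 4])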

4. Inspect the model structure

model=ResNet50V2(classes=4).to(device) #4 bird classes
model

Output:

ResNet50V2(
  (conv1): Sequential(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv2): Stack2(
    (conv): Sequential(
      (0): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv1): Sequential(
          (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (2): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)
        (conv1): Sequential(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (conv3): Stack2(
    (conv): Sequential(
      (0): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv1): Sequential(
          (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (2): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (3): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)
        (conv1): Sequential(
          (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (conv4): Stack2(
    (conv): Sequential(
      (0): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv1): Sequential(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (2): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (3): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (4): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (5): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)
        (conv1): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (conv5): Stack2(
    (conv): Sequential(
      (0): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv1): Sequential(
          (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (2): Block2(
        (preact): Sequential(
          (0): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
        )
        (short): Identity()
        (conv1): Sequential(
          (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (post): Sequential(
    (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (avg_pool): AdaptiveAvgPool2d(output_size=(1, 1))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc): Linear(in_features=2048, out_features=4, bias=True)
  )
)

5. Print the model summary

#report parameter counts, per-layer output shapes, and memory estimates
import torchsummary as summary
summary.summary(model,(3,224,224))

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,472
         MaxPool2d-2           [-1, 64, 56, 56]               0
       BatchNorm2d-3           [-1, 64, 56, 56]             128
              ReLU-4           [-1, 64, 56, 56]               0
            Conv2d-5          [-1, 256, 56, 56]          16,384
            Conv2d-6           [-1, 64, 56, 56]           4,096
       BatchNorm2d-7           [-1, 64, 56, 56]             128
              ReLU-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]          36,864
      BatchNorm2d-10           [-1, 64, 56, 56]             128
             ReLU-11           [-1, 64, 56, 56]               0
           Conv2d-12          [-1, 256, 56, 56]          16,384
           Block2-13          [-1, 256, 56, 56]               0
      BatchNorm2d-14          [-1, 256, 56, 56]             512
             ReLU-15          [-1, 256, 56, 56]               0
         Identity-16          [-1, 256, 56, 56]               0
           Conv2d-17           [-1, 64, 56, 56]          16,384
      BatchNorm2d-18           [-1, 64, 56, 56]             128
             ReLU-19           [-1, 64, 56, 56]               0
           Conv2d-20           [-1, 64, 56, 56]          36,864
      BatchNorm2d-21           [-1, 64, 56, 56]             128
             ReLU-22           [-1, 64, 56, 56]               0
           Conv2d-23          [-1, 256, 56, 56]          16,384
           Block2-24          [-1, 256, 56, 56]               0
      BatchNorm2d-25          [-1, 256, 56, 56]             512
             ReLU-26          [-1, 256, 56, 56]               0
        MaxPool2d-27          [-1, 256, 28, 28]               0
           Conv2d-28           [-1, 64, 56, 56]          16,384
      BatchNorm2d-29           [-1, 64, 56, 56]             128
             ReLU-30           [-1, 64, 56, 56]               0
           Conv2d-31           [-1, 64, 28, 28]          36,864
      BatchNorm2d-32           [-1, 64, 28, 28]             128
             ReLU-33           [-1, 64, 28, 28]               0
           Conv2d-34          [-1, 256, 28, 28]          16,384
           Block2-35          [-1, 256, 28, 28]               0
           Stack2-36          [-1, 256, 28, 28]               0
      BatchNorm2d-37          [-1, 256, 28, 28]             512
             ReLU-38          [-1, 256, 28, 28]               0
           Conv2d-39          [-1, 512, 28, 28]         131,072
           Conv2d-40          [-1, 128, 28, 28]          32,768
      BatchNorm2d-41          [-1, 128, 28, 28]             256
             ReLU-42          [-1, 128, 28, 28]               0
           Conv2d-43          [-1, 128, 28, 28]         147,456
      BatchNorm2d-44          [-1, 128, 28, 28]             256
             ReLU-45          [-1, 128, 28, 28]               0
           Conv2d-46          [-1, 512, 28, 28]          65,536
           Block2-47          [-1, 512, 28, 28]               0
      BatchNorm2d-48          [-1, 512, 28, 28]           1,024
             ReLU-49          [-1, 512, 28, 28]               0
         Identity-50          [-1, 512, 28, 28]               0
           Conv2d-51          [-1, 128, 28, 28]          65,536
      BatchNorm2d-52          [-1, 128, 28, 28]             256
             ReLU-53          [-1, 128, 28, 28]               0
           Conv2d-54          [-1, 128, 28, 28]         147,456
      BatchNorm2d-55          [-1, 128, 28, 28]             256
             ReLU-56          [-1, 128, 28, 28]               0
           Conv2d-57          [-1, 512, 28, 28]          65,536
           Block2-58          [-1, 512, 28, 28]               0
      BatchNorm2d-59          [-1, 512, 28, 28]           1,024
             ReLU-60          [-1, 512, 28, 28]               0
         Identity-61          [-1, 512, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]          65,536
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65          [-1, 128, 28, 28]         147,456
      BatchNorm2d-66          [-1, 128, 28, 28]             256
             ReLU-67          [-1, 128, 28, 28]               0
           Conv2d-68          [-1, 512, 28, 28]          65,536
           Block2-69          [-1, 512, 28, 28]               0
      BatchNorm2d-70          [-1, 512, 28, 28]           1,024
             ReLU-71          [-1, 512, 28, 28]               0
        MaxPool2d-72          [-1, 512, 14, 14]               0
           Conv2d-73          [-1, 128, 28, 28]          65,536
      BatchNorm2d-74          [-1, 128, 28, 28]             256
             ReLU-75          [-1, 128, 28, 28]               0
           Conv2d-76          [-1, 128, 14, 14]         147,456
      BatchNorm2d-77          [-1, 128, 14, 14]             256
             ReLU-78          [-1, 128, 14, 14]               0
           Conv2d-79          [-1, 512, 14, 14]          65,536
           Block2-80          [-1, 512, 14, 14]               0
           Stack2-81          [-1, 512, 14, 14]               0
      BatchNorm2d-82          [-1, 512, 14, 14]           1,024
             ReLU-83          [-1, 512, 14, 14]               0
           Conv2d-84         [-1, 1024, 14, 14]         524,288
           Conv2d-85          [-1, 256, 14, 14]         131,072
      BatchNorm2d-86          [-1, 256, 14, 14]             512
             ReLU-87          [-1, 256, 14, 14]               0
           Conv2d-88          [-1, 256, 14, 14]         589,824
      BatchNorm2d-89          [-1, 256, 14, 14]             512
             ReLU-90          [-1, 256, 14, 14]               0
           Conv2d-91         [-1, 1024, 14, 14]         262,144
           Block2-92         [-1, 1024, 14, 14]               0
      BatchNorm2d-93         [-1, 1024, 14, 14]           2,048
             ReLU-94         [-1, 1024, 14, 14]               0
         Identity-95         [-1, 1024, 14, 14]               0
           Conv2d-96          [-1, 256, 14, 14]         262,144
      BatchNorm2d-97          [-1, 256, 14, 14]             512
             ReLU-98          [-1, 256, 14, 14]               0
           Conv2d-99          [-1, 256, 14, 14]         589,824
     BatchNorm2d-100          [-1, 256, 14, 14]             512
            ReLU-101          [-1, 256, 14, 14]               0
          Conv2d-102         [-1, 1024, 14, 14]         262,144
          Block2-103         [-1, 1024, 14, 14]               0
     BatchNorm2d-104         [-1, 1024, 14, 14]           2,048
            ReLU-105         [-1, 1024, 14, 14]               0
        Identity-106         [-1, 1024, 14, 14]               0
          Conv2d-107          [-1, 256, 14, 14]         262,144
     BatchNorm2d-108          [-1, 256, 14, 14]             512
            ReLU-109          [-1, 256, 14, 14]               0
          Conv2d-110          [-1, 256, 14, 14]         589,824
     BatchNorm2d-111          [-1, 256, 14, 14]             512
            ReLU-112          [-1, 256, 14, 14]               0
          Conv2d-113         [-1, 1024, 14, 14]         262,144
          Block2-114         [-1, 1024, 14, 14]               0
     BatchNorm2d-115         [-1, 1024, 14, 14]           2,048
            ReLU-116         [-1, 1024, 14, 14]               0
        Identity-117         [-1, 1024, 14, 14]               0
          Conv2d-118          [-1, 256, 14, 14]         262,144
     BatchNorm2d-119          [-1, 256, 14, 14]             512
            ReLU-120          [-1, 256, 14, 14]               0
          Conv2d-121          [-1, 256, 14, 14]         589,824
     BatchNorm2d-122          [-1, 256, 14, 14]             512
            ReLU-123          [-1, 256, 14, 14]               0
          Conv2d-124         [-1, 1024, 14, 14]         262,144
          Block2-125         [-1, 1024, 14, 14]               0
     BatchNorm2d-126         [-1, 1024, 14, 14]           2,048
            ReLU-127         [-1, 1024, 14, 14]               0
        Identity-128         [-1, 1024, 14, 14]               0
          Conv2d-129          [-1, 256, 14, 14]         262,144
     BatchNorm2d-130          [-1, 256, 14, 14]             512
            ReLU-131          [-1, 256, 14, 14]               0
          Conv2d-132          [-1, 256, 14, 14]         589,824
     BatchNorm2d-133          [-1, 256, 14, 14]             512
            ReLU-134          [-1, 256, 14, 14]               0
          Conv2d-135         [-1, 1024, 14, 14]         262,144
          Block2-136         [-1, 1024, 14, 14]               0
     BatchNorm2d-137         [-1, 1024, 14, 14]           2,048
            ReLU-138         [-1, 1024, 14, 14]               0
       MaxPool2d-139           [-1, 1024, 7, 7]               0
          Conv2d-140          [-1, 256, 14, 14]         262,144
     BatchNorm2d-141          [-1, 256, 14, 14]             512
            ReLU-142          [-1, 256, 14, 14]               0
          Conv2d-143            [-1, 256, 7, 7]         589,824
     BatchNorm2d-144            [-1, 256, 7, 7]             512
            ReLU-145            [-1, 256, 7, 7]               0
          Conv2d-146           [-1, 1024, 7, 7]         262,144
          Block2-147           [-1, 1024, 7, 7]               0
          Stack2-148           [-1, 1024, 7, 7]               0
     BatchNorm2d-149           [-1, 1024, 7, 7]           2,048
            ReLU-150           [-1, 1024, 7, 7]               0
          Conv2d-151           [-1, 2048, 7, 7]       2,097,152
          Conv2d-152            [-1, 512, 7, 7]         524,288
     BatchNorm2d-153            [-1, 512, 7, 7]           1,024
            ReLU-154            [-1, 512, 7, 7]               0
          Conv2d-155            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-156            [-1, 512, 7, 7]           1,024
            ReLU-157            [-1, 512, 7, 7]               0
          Conv2d-158           [-1, 2048, 7, 7]       1,048,576
          Block2-159           [-1, 2048, 7, 7]               0
     BatchNorm2d-160           [-1, 2048, 7, 7]           4,096
            ReLU-161           [-1, 2048, 7, 7]               0
        Identity-162           [-1, 2048, 7, 7]               0
          Conv2d-163            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-164            [-1, 512, 7, 7]           1,024
            ReLU-165            [-1, 512, 7, 7]               0
          Conv2d-166            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-167            [-1, 512, 7, 7]           1,024
            ReLU-168            [-1, 512, 7, 7]               0
          Conv2d-169           [-1, 2048, 7, 7]       1,048,576
          Block2-170           [-1, 2048, 7, 7]               0
     BatchNorm2d-171           [-1, 2048, 7, 7]           4,096
            ReLU-172           [-1, 2048, 7, 7]               0
        Identity-173           [-1, 2048, 7, 7]               0
          Conv2d-174            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-175            [-1, 512, 7, 7]           1,024
            ReLU-176            [-1, 512, 7, 7]               0
          Conv2d-177            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-178            [-1, 512, 7, 7]           1,024
            ReLU-179            [-1, 512, 7, 7]               0
          Conv2d-180           [-1, 2048, 7, 7]       1,048,576
          Block2-181           [-1, 2048, 7, 7]               0
          Stack2-182           [-1, 2048, 7, 7]               0
     BatchNorm2d-183           [-1, 2048, 7, 7]           4,096
            ReLU-184           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-185           [-1, 2048, 1, 1]               0
         Flatten-186                 [-1, 2048]               0
          Linear-187                    [-1, 4]           8,196
================================================================
Total params: 23,508,612
Trainable params: 23,508,612
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 241.68
Params size (MB): 89.68
Estimated Total Size (MB): 331.93
----------------------------------------------------------------

III. Training the Model

1. The training function

#training function
def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset) #size of the training set
    num_batches=len(dataloader)  #number of batches
    
    train_loss,train_acc=0,0 #accumulated loss and number of correct predictions
    
    for x,y in dataloader: #fetch a batch of images and labels
        x,y=x.to(device),y.to(device)
        
        #compute the prediction error
        pred=model(x) #network output
        loss=loss_fn(pred,y)  #loss between the network output and the ground truth
        
        #backpropagation
        optimizer.zero_grad() #reset the gradients
        loss.backward() #backpropagate
        optimizer.step() #update the weights
        
        #accumulate accuracy and loss
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()
    
    train_acc/=size
    train_loss/=num_batches
    
    return train_acc,train_loss

2. The test function

The test function is largely the same as the training function, except that no optimizer is passed in, since no gradient descent is performed to update the network weights.

#test function
def test(dataloader,model,loss_fn):
    size=len(dataloader.dataset) #size of the test set
    num_batches=len(dataloader) #number of batches
    test_loss,test_acc=0,0
    
    #disable gradient tracking during evaluation to save memory
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target=imgs.to(device),target.to(device)
            
            #compute the loss
            target_pred=model(imgs)
            loss=loss_fn(target_pred,target)
            
            test_loss+=loss.item()
            test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()
            
    test_acc/=size
    test_loss/=num_batches
    
    return test_acc,test_loss

3. Run the training

import copy
opt=torch.optim.Adam(model.parameters(),lr=1e-4) #create the optimizer and set the learning rate
loss_fn=nn.CrossEntropyLoss()  #create the loss function

epochs=10

train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]

best_acc=0 #best test accuracy so far, used to select the best model

for epoch in range(epochs):
    
    model.train()
    epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,opt)
    
    model.eval()
    epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
    
    #keep a copy of the best model in J2_model
    if epoch_test_acc>best_acc:
        best_acc=epoch_test_acc
        J2_model=copy.deepcopy(model)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    #read the current learning rate
    lr=opt.state_dict()['param_groups'][0]['lr']
    
    template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,
                          epoch_test_acc*100,epoch_test_loss,lr))

#save the best model (J2_model), not the last-epoch model, to disk
PATH=r'D:\THE MNIST DATABASE\J-series\J2_model.pth'
torch.save(J2_model.state_dict(),PATH)

Output:

Epoch: 1,Train_acc:75.2%,Train_loss:0.674,Test_acc:70.8%,Test_loss:1.028,Lr:1.00E-04
Epoch: 2,Train_acc:76.8%,Train_loss:0.629,Test_acc:75.2%,Test_loss:0.868,Lr:1.00E-04
Epoch: 3,Train_acc:86.9%,Train_loss:0.398,Test_acc:79.6%,Test_loss:1.163,Lr:1.00E-04
Epoch: 4,Train_acc:91.4%,Train_loss:0.290,Test_acc:74.3%,Test_loss:0.954,Lr:1.00E-04
Epoch: 5,Train_acc:90.0%,Train_loss:0.296,Test_acc:76.1%,Test_loss:0.877,Lr:1.00E-04
Epoch: 6,Train_acc:88.1%,Train_loss:0.323,Test_acc:58.4%,Test_loss:1.141,Lr:1.00E-04
Epoch: 7,Train_acc:92.3%,Train_loss:0.243,Test_acc:74.3%,Test_loss:0.770,Lr:1.00E-04
Epoch: 8,Train_acc:94.5%,Train_loss:0.179,Test_acc:77.9%,Test_loss:1.187,Lr:1.00E-04
Epoch: 9,Train_acc:95.8%,Train_loss:0.139,Test_acc:77.9%,Test_loss:0.919,Lr:1.00E-04
Epoch:10,Train_acc:94.9%,Train_loss:0.159,Test_acc:84.1%,Test_loss:0.508,Lr:1.00E-04
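
To reuse the saved weights later, rebuild the architecture and load the state dict; a minimal sketch:

#rebuild the model and restore the best weights saved above
reloaded=ResNet50V2(classes=4).to(device)
reloaded.load_state_dict(torch.load(PATH,map_location=device))
reloaded.eval()  #switch BatchNorm to inference mode before evaluating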

IV. Visualizing the Results

1. Loss and accuracy curves

import matplotlib.pyplot as plt
#hide warnings
import warnings
warnings.filterwarnings("ignore")   #suppress warnings
plt.rcParams['font.sans-serif']=['SimHei']   #display Chinese labels correctly
plt.rcParams['axes.unicode_minus']=False   #display the minus sign correctly
plt.rcParams['figure.dpi']=300   #figure resolution
 
epochs_range=range(epochs)
plt.figure(figsize=(12,3))
 
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

Output: (training/validation accuracy and loss curves over the 10 epochs)

2. Predicting a specified image

from PIL import Image
 
classes=list(total_data.class_to_idx)
 
def predict_one_image(image_path,model,transform,classes):
    
    test_img=Image.open(image_path).convert('RGB')
    plt.imshow(test_img)   #show the image being predicted
    
    test_img=transform(test_img)
    img=test_img.to(device).unsqueeze(0)
    
    model.eval()
    output=model(img)
    
    _,pred=torch.max(output,1)
    pred_class=classes[pred]
    print(f'Predicted class: {pred_class}')

Predict an image:

#predict one photo from the training set
predict_one_image(image_path=r'D:\THE MNIST DATABASE\J-series\J1\bird_photos\Black Skimmer\001.jpg',
                  model=model,transform=train_transforms,classes=classes)

Output:

Predicted class: Black Skimmer
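
Note that predict_one_image is called here with train_transforms, so each call applies a random flip and rotation before inference. For prediction a deterministic pipeline is usually preferable; a sketch (the test_transforms name is my own):

#deterministic preprocessing for inference: no random augmentation
test_transforms=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])

Passing transform=test_transforms to predict_one_image makes the prediction repeatable.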

3. Model evaluation

J2_model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,J2_model,loss_fn)
epoch_test_acc,epoch_test_loss

Output:

(0.8495575221238938, 0.5142213940271176)

V. Reflections

In this week's project I built a ResNet50V2 model by hand in PyTorch. Compared with ResNet50, the V2 residual block moves BN and ReLU in front of the convolutions (pre-activation), which to some extent improves model accuracy. In this particular project, however, the results were not as good as hoped; even though data augmentation was applied, the dataset itself may simply be too small, which is left for future verification.
