Week P9: Implementing the YOLOv5 Backbone Module

Environment:

Python version:  3.8.17 (default, Jul  5 2023, 20:44:21) [MSC v.1916 64 bit (AMD64)]
Pytorch version:  2.0.1+cu117
Torchvision version:  0.15.2+cu117
CUDA is available: True
Using device: cuda

The weather dataset used here was provided by K同学; contact K同学 if you need a copy.

I. Preparation

1. Import the required packages and set up the GPU

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

import pathlib, random, copy
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
from torchinfo import summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

Output:

cuda

2. Data preprocessing and visualization

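Place the downloaded dataset in the same folder as the script and set the path. The code below collects all image paths, prints the path and shape of five random images, and displays twenty random images with their class labels.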

image_path = 'weather_photos'

# List of all image paths (one sub-folder per class)
image_list = list(pathlib.Path(image_path).glob('*/*'))

# Print the path and shape of a few random images
for _ in range(5):
    image = random.choice(image_list)
    print(f"{str(image)}, shape is: {np.array(Image.open(str(image))).shape}")

# Display 20 random images, titled with their class folder name
plt.figure(figsize=(20, 4))
for i in range(20):
    plt.subplot(2, 10, i + 1)
    plt.axis('off')
    image = random.choice(image_list)
    plt.title(image.parts[-2])
    plt.imshow(Image.open(str(image)))

Output:

weather_photos\cloudy\cloudy167.jpg, shape is: (167, 222, 3)
weather_photos\cloudy\cloudy290.jpg, shape is: (185, 298, 3)
weather_photos\sunrise\sunrise324.jpg, shape is: (168, 252, 3)
weather_photos\sunrise\sunrise67.jpg, shape is: (194, 259, 3)
weather_photos\rain\rain142.jpg, shape is: (420, 640, 3)

Apply the preprocessing transforms and print the class names.

img_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
dataset = datasets.ImageFolder(image_path, transform=img_transform)
class_names = [k for k in dataset.class_to_idx]
print(class_names)

Class names:

['cloudy', 'rain', 'shine', 'sunrise']
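
As a quick check on what the `Normalize` step in the transform above does, the sketch below applies the same per-channel formula, (x - mean) / std, to single pixels in plain Python. The helper name `normalize_pixel` is made up for this illustration; the mean and std values are the ones passed to `transforms.Normalize` above.

```python
# Per-channel normalization as applied by transforms.Normalize:
# output = (input - mean) / std, after ToTensor has scaled pixels to [0, 1].
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Normalize one RGB pixel whose channels are already in [0, 1]."""
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


black = normalize_pixel([0.0, 0.0, 0.0])
white = normalize_pixel([1.0, 1.0, 1.0])
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

A pixel equal to the channel means maps to zero, so the network sees inputs roughly centered around zero rather than in [0, 1].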

Split the dataset into training and test sets at an 80/20 ratio, set the batch size, and create the training and test data loaders.

train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
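
The split arithmetic above can be reproduced without PyTorch. The sketch below mirrors the size computation and counts how many batches each loader would yield. The helper names and the dataset size of 1125 are assumptions made for illustration; substitute `len(dataset)` for a real run.

```python
import math


def split_sizes(num_samples, train_fraction=0.8):
    """Mirror the train/test size computation used with random_split."""
    train_size = int(num_samples * train_fraction)
    test_size = num_samples - train_size
    return train_size, test_size


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a dataset of this size."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


# 1125 is an assumed dataset size, used for illustration only.
train_size, test_size = split_sizes(1125)
print(train_size, test_size)
print(num_batches(train_size, 32), num_batches(test_size, 32))
```

Note that `random_split` as called above is unseeded, so the split differs between runs; passing a seeded `torch.Generator` through its `generator` argument is one way to make it reproducible.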

II. Model construction and visualization

1. Background

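YOLOv5 is a lightweight object detection algorithm, and its backbone is the main network used for feature extraction. The YOLOv5 backbone is CSPDarknet53, where CSP stands for Cross Stage Partial network and Darknet is the convolutional backbone family used by earlier YOLO versions.

CSPDarknet53 is a key component of YOLOv5 and an improvement on Darknet53. It introduces the CSP structure and the SPP structure to improve the network's performance and efficiency.

The CSP structure splits the input feature map into two branches. One branch passes through a single convolution; the other passes through a convolution followed by a stack of bottleneck blocks. The two outputs are then concatenated along the channel dimension and fused by a final convolution, as in the C3 module implemented below. Because each branch works at a reduced channel width, computation and memory use drop while the network keeps its representational capacity for detection.

The SPP structure is Spatial Pyramid Pooling. It pools the same feature map over windows of several sizes and concatenates the results, so features are gathered at multiple scales. This enlarges the receptive field and helps the network detect objects of different sizes.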
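
The SPPF module used later in this post chains three 5x5 max-pools instead of running pools of different sizes in parallel. The sketch below checks, on a one-dimensional toy signal, that stacking stride-1 pools of size 5 covers the same windows as single pools of size 9 and 13. The helper `max_pool_1d` and the sample signal are made up for this illustration, and padding with negative infinity is an assumption chosen so that padded positions never win the max.

```python
NEG_INF = float("-inf")


def max_pool_1d(values, kernel_size):
    """Stride-1 max pooling with kernel_size // 2 padding on both sides."""
    pad = kernel_size // 2
    padded = [NEG_INF] * pad + list(values) + [NEG_INF] * pad
    return [max(padded[i:i + kernel_size]) for i in range(len(values))]


signal = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3]

once = max_pool_1d(signal, 5)
twice = max_pool_1d(once, 5)
thrice = max_pool_1d(twice, 5)

# Stacking n stride-1 pools of size k covers a window of 1 + n * (k - 1).
print(twice == max_pool_1d(signal, 9))
print(thrice == max_pool_1d(signal, 13))
```

Reusing the output of each pool for the next one is what makes the chained form cheaper than three independent pools over the same input.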

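The architecture of the YOLOv5 backbone can be summarized in the following steps:

1. The input image passes through a series of convolutions that extract low-level features.
2. A CSP structure splits these features into two branches: one goes through a single convolution, the other through a convolution and a stack of bottleneck blocks, and the two are concatenated again.
3. The features pass through repeated residual and CSP structures, gradually extracting higher-level features.
4. Finally, the SPP structure applies spatial pyramid pooling to fuse features at different scales.
5. The resulting feature map is used for prediction. In the full detector it feeds the neck and the convolutional detection head; in this post it is flattened and passed to fully connected layers for weather classification.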
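The architecture diagram provided by K同学 helps illustrate this structure.
[Figure: YOLOv5 architecture diagram provided by K同学]

YOLOv5 consists of the following main components: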

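Backbone: YOLOv5 uses a strong backbone network to extract image features. The default backbone is CSPDarknet53, a deep convolutional neural network with strong feature extraction ability.

Neck: YOLOv5 uses a PANet-style feature pyramid structure to fuse feature maps of different scales. PANet adds a bottom-up path to the top-down feature pyramid, which strengthens the feature maps so that objects of different sizes are detected more reliably.

Head: The YOLOv5 head consists of convolutional layers that generate the detection predictions. It converts the feature maps into bounding-box locations, objectness scores, and class predictions.

Loss Function: YOLOv5 is trained with a multi-task loss in the style of YOLOv3. It combines a bounding-box location loss, an objectness (confidence) loss, and a classification loss to minimize detection error.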

2. Code and model summary

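def pad(kernel_size, padding=None):
    # Default to "same"-style padding (kernel_size // 2) when none is given
    if padding is None:
        padding = kernel_size // 2 if isinstance(kernel_size, int) else [k // 2 for k in kernel_size]
    return padding


class Conv(nn.Module):
    # Conv2d -> BatchNorm2d -> activation (SiLU by default)
    def __init__(self, ch_in, ch_out, kernel_size, stride=1, padding=None, groups=1, activation=True):
        super().__init__()

        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, pad(kernel_size, padding), groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(ch_out)
        self.act = nn.SiLU() if activation is True else (
            activation if isinstance(activation, nn.Module) else nn.Identity())

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        return x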


class Bottleneck(nn.Module):
    def __init__(self, ch_in, ch_out, shortcut=True, groups=1, factor=0.5):
        super().__init__()

        hidden_size = int(ch_out * factor)
        self.conv1 = Conv(ch_in, hidden_size, 1)
        self.conv2 = Conv(hidden_size, ch_out, 3)
        self.add = shortcut and ch_in == ch_out

    def forward(self, x):
        return x + self.conv2(self.conv1(x)) if self.add else self.conv2(self.conv1(x))


class C3(nn.Module):
    def __init__(self, ch_in, ch_out, n=1, shortcut=True, groups=1, factor=0.5):
        super().__init__()

        hidden_size = int(ch_out * factor)
        self.conv1 = Conv(ch_in, hidden_size, 1)
        self.conv2 = Conv(ch_in, hidden_size, 1)
        self.conv3 = Conv(2 * hidden_size, ch_out, 1)
        self.m = nn.Sequential(*(Bottleneck(hidden_size, hidden_size) for _ in range(n)))

    def forward(self, x):
        return self.conv3(torch.cat((self.conv1(x), self.m(self.conv2(x))), dim=1))


class SPPF(nn.Module):
    def __init__(self, ch_in, ch_out, kernel_size=5):
        super().__init__()

        hidden_size = ch_in // 2
        self.conv1 = Conv(ch_in, hidden_size, 1)
        self.conv2 = Conv(4 * hidden_size, ch_out, 1)
        self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        x = self.conv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        y3 = self.m(y2)
        return self.conv2(torch.cat([x, y1, y2, y3], dim=1))


class Network(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.conv1 = Conv(3, 64, 3, 2, 2)
        self.conv2 = Conv(64, 128, 3, 2)
        self.c3_1 = C3(128, 128)
        self.conv3 = Conv(128, 256, 3, 2)
        self.c3_2 = C3(256, 256)
        self.conv4 = Conv(256, 512, 3, 2)
        self.c3_3 = C3(512, 512)
        self.conv5 = Conv(512, 1024, 3, 2)
        self.c3_4 = C3(1024, 1024)
        self.sppf = SPPF(1024, 1024, 5)

        self.classifier = nn.Sequential(
            nn.Linear(65536, 100),
            nn.ReLU(),
            nn.Linear(100, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.c3_1(x)
        x = self.conv3(x)
        x = self.c3_2(x)
        x = self.conv4(x)
        x = self.c3_3(x)
        x = self.conv5(x)
        x = self.c3_4(x)
        x = self.sppf(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
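
The input width of the first `Linear` layer (65536) follows from the spatial size left after the five stride-2 convolutions. The sketch below traces that size for a 224x224 input using the standard convolution output formula. The helper name `conv_out_size` is made up for this illustration; the C3 and SPPF blocks are left out because their stride-1 layers keep the spatial size unchanged.

```python
def conv_out_size(size, kernel_size, stride, padding):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) for conv1..conv5 in Network.
# conv1 is given padding=2 explicitly; the others use kernel_size // 2 = 1.
downsampling_convs = [(3, 2, 2), (3, 2, 1), (3, 2, 1), (3, 2, 1), (3, 2, 1)]

size = 224
sizes = []
for kernel_size, stride, padding in downsampling_convs:
    size = conv_out_size(size, kernel_size, stride, padding)
    sizes.append(size)

# 1024 channels come out of the SPPF block.
flattened = 1024 * size * size
print(sizes)
print(flattened)
```

If the input resolution or any padding value changes, the flattened size changes with it and the first `Linear` layer has to be resized to match.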

Print the network architecture and a layer-by-layer summary.

model = Network(len(class_names)).to(device)
print(model)
summary(model, input_size=(32, 3, 224, 224))

Network(
(conv1): Conv(
(conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_1): C3(
(conv1): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(conv3): Conv(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_2): C3(
(conv1): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(conv4): Conv(
(conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_3): C3(
(conv1): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(conv5): Conv(
(conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_4): C3(
(conv1): Conv(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(sppf): SPPF(
(conv1): Conv(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=65536, out_features=100, bias=True)
(1): ReLU()
(2): Linear(in_features=100, out_features=4, bias=True)
)
)

===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
Network [32, 4] –
├─Conv: 1-1 [32, 64, 113, 113] –
│ └─Conv2d: 2-1 [32, 64, 113, 113] 1,728
│ └─BatchNorm2d: 2-2 [32, 64, 113, 113] 128
│ └─SiLU: 2-3 [32, 64, 113, 113] –
├─Conv: 1-2 [32, 128, 57, 57] –
│ └─Conv2d: 2-4 [32, 128, 57, 57] 73,728
│ └─BatchNorm2d: 2-5 [32, 128, 57, 57] 256
│ └─SiLU: 2-6 [32, 128, 57, 57] –
├─C3: 1-3 [32, 128, 57, 57] –
│ └─Conv: 2-7 [32, 64, 57, 57] –
│ │ └─Conv2d: 3-1 [32, 64, 57, 57] 8,192
│ │ └─BatchNorm2d: 3-2 [32, 64, 57, 57] 128
│ │ └─SiLU: 3-3 [32, 64, 57, 57] –
│ └─Conv: 2-8 [32, 64, 57, 57] –
│ │ └─Conv2d: 3-4 [32, 64, 57, 57] 8,192
│ │ └─BatchNorm2d: 3-5 [32, 64, 57, 57] 128
│ │ └─SiLU: 3-6 [32, 64, 57, 57] –
│ └─Sequential: 2-9 [32, 64, 57, 57] –
│ │ └─Bottleneck: 3-7 [32, 64, 57, 57] 20,672
│ └─Conv: 2-10 [32, 128, 57, 57] –
│ │ └─Conv2d: 3-8 [32, 128, 57, 57] 16,384
│ │ └─BatchNorm2d: 3-9 [32, 128, 57, 57] 256
│ │ └─SiLU: 3-10 [32, 128, 57, 57] –
├─Conv: 1-4 [32, 256, 29, 29] –
│ └─Conv2d: 2-11 [32, 256, 29, 29] 294,912
│ └─BatchNorm2d: 2-12 [32, 256, 29, 29] 512
│ └─SiLU: 2-13 [32, 256, 29, 29] –
├─C3: 1-5 [32, 256, 29, 29] –
│ └─Conv: 2-14 [32, 128, 29, 29] –
│ │ └─Conv2d: 3-11 [32, 128, 29, 29] 32,768
│ │ └─BatchNorm2d: 3-12 [32, 128, 29, 29] 256
│ │ └─SiLU: 3-13 [32, 128, 29, 29] –
│ └─Conv: 2-15 [32, 128, 29, 29] –
│ │ └─Conv2d: 3-14 [32, 128, 29, 29] 32,768
│ │ └─BatchNorm2d: 3-15 [32, 128, 29, 29] 256
│ │ └─SiLU: 3-16 [32, 128, 29, 29] –
│ └─Sequential: 2-16 [32, 128, 29, 29] –
│ │ └─Bottleneck: 3-17 [32, 128, 29, 29] 82,304
│ └─Conv: 2-17 [32, 256, 29, 29] –
│ │ └─Conv2d: 3-18 [32, 256, 29, 29] 65,536
│ │ └─BatchNorm2d: 3-19 [32, 256, 29, 29] 512
│ │ └─SiLU: 3-20 [32, 256, 29, 29] –
├─Conv: 1-6 [32, 512, 15, 15] –
│ └─Conv2d: 2-18 [32, 512, 15, 15] 1,179,648
│ └─BatchNorm2d: 2-19 [32, 512, 15, 15] 1,024
│ └─SiLU: 2-20 [32, 512, 15, 15] –
├─C3: 1-7 [32, 512, 15, 15] –
│ └─Conv: 2-21 [32, 256, 15, 15] –
│ │ └─Conv2d: 3-21 [32, 256, 15, 15] 131,072
│ │ └─BatchNorm2d: 3-22 [32, 256, 15, 15] 512
│ │ └─SiLU: 3-23 [32, 256, 15, 15] –
│ └─Conv: 2-22 [32, 256, 15, 15] –
│ │ └─Conv2d: 3-24 [32, 256, 15, 15] 131,072
│ │ └─BatchNorm2d: 3-25 [32, 256, 15, 15] 512
│ │ └─SiLU: 3-26 [32, 256, 15, 15] –
│ └─Sequential: 2-23 [32, 256, 15, 15] –
│ │ └─Bottleneck: 3-27 [32, 256, 15, 15] 328,448
│ └─Conv: 2-24 [32, 512, 15, 15] –
│ │ └─Conv2d: 3-28 [32, 512, 15, 15] 262,144
│ │ └─BatchNorm2d: 3-29 [32, 512, 15, 15] 1,024
│ │ └─SiLU: 3-30 [32, 512, 15, 15] –
├─Conv: 1-8 [32, 1024, 8, 8] –
│ └─Conv2d: 2-25 [32, 1024, 8, 8] 4,718,592
│ └─BatchNorm2d: 2-26 [32, 1024, 8, 8] 2,048
│ └─SiLU: 2-27 [32, 1024, 8, 8] –
├─C3: 1-9 [32, 1024, 8, 8] –
│ └─Conv: 2-28 [32, 512, 8, 8] –
│ │ └─Conv2d: 3-31 [32, 512, 8, 8] 524,288
│ │ └─BatchNorm2d: 3-32 [32, 512, 8, 8] 1,024
│ │ └─SiLU: 3-33 [32, 512, 8, 8] –
│ └─Conv: 2-29 [32, 512, 8, 8] –
│ │ └─Conv2d: 3-34 [32, 512, 8, 8] 524,288
│ │ └─BatchNorm2d: 3-35 [32, 512, 8, 8] 1,024
│ │ └─SiLU: 3-36 [32, 512, 8, 8] –
│ └─Sequential: 2-30 [32, 512, 8, 8] –
│ │ └─Bottleneck: 3-37 [32, 512, 8, 8] 1,312,256
│ └─Conv: 2-31 [32, 1024, 8, 8] –
│ │ └─Conv2d: 3-38 [32, 1024, 8, 8] 1,048,576
│ │ └─BatchNorm2d: 3-39 [32, 1024, 8, 8] 2,048
│ │ └─SiLU: 3-40 [32, 1024, 8, 8] –
├─SPPF: 1-10 [32, 1024, 8, 8] –
│ └─Conv: 2-32 [32, 512, 8, 8] –
│ │ └─Conv2d: 3-41 [32, 512, 8, 8] 524,288
│ │ └─BatchNorm2d: 3-42 [32, 512, 8, 8] 1,024
│ │ └─SiLU: 3-43 [32, 512, 8, 8] –
│ └─MaxPool2d: 2-33 [32, 512, 8, 8] –
│ └─MaxPool2d: 2-34 [32, 512, 8, 8] –
│ └─MaxPool2d: 2-35 [32, 512, 8, 8] –
│ └─Conv: 2-36 [32, 1024, 8, 8] –
│ │ └─Conv2d: 3-44 [32, 1024, 8, 8] 2,097,152
│ │ └─BatchNorm2d: 3-45 [32, 1024, 8, 8] 2,048
│ │ └─SiLU: 3-46 [32, 1024, 8, 8] –
├─Sequential: 1-11 [32, 4] –
│ └─Linear: 2-37 [32, 100] 6,553,700
│ └─ReLU: 2-38 [32, 100] –
│ └─Linear: 2-39 [32, 4] 404
===============================================================================================
Total params: 19,987,832
Trainable params: 19,987,832
Non-trainable params: 0
Total mult-adds (G): 64.43
===============================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 2027.63
Params size (MB): 79.95
Estimated Total Size (MB): 2126.85
===============================================================================================
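
The parameter counts in the summary can be checked by hand. The sketch below recomputes them from the layer definitions, assuming bias-free convolutions followed by BatchNorm (two trainable values per output channel), as in the `Conv` class. All helper names are made up for this illustration.

```python
def conv_bn_params(ch_in, ch_out, kernel_size):
    """Weights of a bias-free Conv2d plus BatchNorm2d scale and shift."""
    return ch_in * ch_out * kernel_size * kernel_size + 2 * ch_out


def bottleneck_params(ch_in, ch_out, factor=0.5):
    hidden = int(ch_out * factor)
    return conv_bn_params(ch_in, hidden, 1) + conv_bn_params(hidden, ch_out, 3)


def c3_params(ch_in, ch_out, n=1, factor=0.5):
    hidden = int(ch_out * factor)
    branches = 2 * conv_bn_params(ch_in, hidden, 1)
    fuse = conv_bn_params(2 * hidden, ch_out, 1)
    return branches + fuse + n * bottleneck_params(hidden, hidden)


def sppf_params(ch_in, ch_out):
    hidden = ch_in // 2
    return conv_bn_params(ch_in, hidden, 1) + conv_bn_params(4 * hidden, ch_out, 1)


def linear_params(features_in, features_out):
    return features_in * features_out + features_out


# Walk the backbone stage by stage: a stride-2 Conv followed by a C3 block.
total = conv_bn_params(3, 64, 3)
for ch_in, ch_out in [(64, 128), (128, 256), (256, 512), (512, 1024)]:
    total += conv_bn_params(ch_in, ch_out, 3) + c3_params(ch_out, ch_out)
total += sppf_params(1024, 1024)
total += linear_params(65536, 100) + linear_params(100, 4)

print(bottleneck_params(64, 64))
print(linear_params(65536, 100))
print(total)
```

The first `Linear` layer alone accounts for about a third of all parameters, which suggests that pooling the feature map before the classifier would shrink the model considerably.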

III. Training and test functions

1. Training function

def train(train_loader, model, loss_fn, optimizer):
    model.train()
    train_loss, train_acc = 0, 0
    num_batches = len(train_loader)
    size = len(train_loader.dataset)

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    train_loss /= num_batches
    train_acc /= size

    return train_loss, train_acc

2. Test function

def test(test_loader, model, loss_fn):
    model.eval()
    test_loss, test_acc = 0, 0
    num_batches = len(test_loader)
    size = len(test_loader.dataset)

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            test_loss += loss.item()
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    test_acc /= size

    return test_loss, test_acc
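
Both functions above average the loss over batches but the accuracy over samples. The sketch below reproduces that aggregation in plain Python; the function name and the numbers are made up for illustration.

```python
def epoch_metrics(batch_losses, batch_correct, num_samples):
    """Aggregate per-batch results the same way train() and test() do."""
    mean_loss = sum(batch_losses) / len(batch_losses)  # average over batches
    accuracy = sum(batch_correct) / num_samples        # average over samples
    return mean_loss, accuracy


# Made-up numbers: three batches holding 32, 32 and 16 samples.
losses = [0.5, 0.25, 0.75]
correct = [24, 28, 12]
loss, acc = epoch_metrics(losses, correct, num_samples=80)
print(loss, acc)
```

One consequence is that a smaller final batch counts as much as a full batch in the reported loss, so the figure is a per-batch mean rather than an exact per-sample mean.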

3. Loss function, learning rate, and optimizer

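Train for 50 epochs with the cross-entropy loss and the Adam optimizer at a learning rate of 1e-4, and set the path where the best weights will be saved.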

epochs = 50
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
best_acc = 0
best_model_path = 'best_p9_model.pth'

train_loss, train_acc = [], []
test_loss, test_acc = [], []
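
`nn.CrossEntropyLoss` takes raw logits and applies log-softmax internally. As a rough guide to reading the loss values, the sketch below computes the per-sample cross-entropy in plain Python; the function name and the example logits are made up for illustration.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]


# Four classes, as in the weather dataset.
uniform = cross_entropy([0.0, 0.0, 0.0, 0.0], target=2)
confident = cross_entropy([0.0, 0.0, 10.0, 0.0], target=2)
print(round(uniform, 3), round(confident, 5))
```

With four classes, a model that outputs equal logits scores ln(4), so a loss well above that level means the model is confidently wrong on part of the data.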

IV. Training and visualization

for epoch in range(epochs):
    epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)

    if best_acc < epoch_test_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

    lr = optimizer.state_dict()['param_groups'][0]['lr']

    print(
        f"Epoch: {epoch + 1}, TrainLoss: {epoch_train_loss:.3f}, TrainAcc: {epoch_train_acc * 100:.1f},TestLoss: {epoch_test_loss:.3f}, TestAcc: {epoch_test_acc * 100:.1f}, learning_rate: {lr}")
print(f"training finished, save best model to : {best_model_path})")
torch.save(best_model.state_dict(), best_model_path)
print("done")

The metrics for each epoch are computed and printed, and the best weights are saved at the end.

Epoch: 1, TrainLoss: 0.854, TrainAcc: 63.8,TestLoss: 2.602, TestAcc: 27.1, learning_rate: 0.0001
Epoch: 2, TrainLoss: 0.783, TrainAcc: 73.1,TestLoss: 0.471, TestAcc: 75.6, learning_rate: 0.0001
Epoch: 3, TrainLoss: 0.467, TrainAcc: 80.4,TestLoss: 0.363, TestAcc: 84.4, learning_rate: 0.0001
Epoch: 4, TrainLoss: 0.431, TrainAcc: 89.3,TestLoss: 0.466, TestAcc: 84.4, learning_rate: 0.0001
Epoch: 5, TrainLoss: 0.328, TrainAcc: 89.9,TestLoss: 0.343, TestAcc: 89.3, learning_rate: 0.0001
Epoch: 6, TrainLoss: 0.382, TrainAcc: 87.2,TestLoss: 0.260, TestAcc: 87.1, learning_rate: 0.0001
Epoch: 7, TrainLoss: 0.258, TrainAcc: 90.9,TestLoss: 0.299, TestAcc: 88.4, learning_rate: 0.0001
Epoch: 8, TrainLoss: 0.163, TrainAcc: 93.9,TestLoss: 0.409, TestAcc: 87.1, learning_rate: 0.0001
Epoch: 9, TrainLoss: 0.174, TrainAcc: 93.0,TestLoss: 0.309, TestAcc: 88.0, learning_rate: 0.0001
Epoch: 10, TrainLoss: 0.283, TrainAcc: 92.3,TestLoss: 0.306, TestAcc: 88.4, learning_rate: 0.0001
Epoch: 11, TrainLoss: 0.246, TrainAcc: 95.0,TestLoss: 0.162, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 12, TrainLoss: 0.172, TrainAcc: 93.9,TestLoss: 0.241, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 13, TrainLoss: 0.114, TrainAcc: 95.8,TestLoss: 0.148, TestAcc: 96.0, learning_rate: 0.0001
Epoch: 14, TrainLoss: 0.082, TrainAcc: 97.4,TestLoss: 0.219, TestAcc: 92.0, learning_rate: 0.0001
Epoch: 15, TrainLoss: 0.111, TrainAcc: 96.4,TestLoss: 0.238, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 16, TrainLoss: 0.130, TrainAcc: 97.9,TestLoss: 0.197, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 17, TrainLoss: 0.206, TrainAcc: 97.4,TestLoss: 0.326, TestAcc: 90.2, learning_rate: 0.0001
Epoch: 18, TrainLoss: 0.101, TrainAcc: 96.7,TestLoss: 0.147, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 19, TrainLoss: 0.080, TrainAcc: 98.4,TestLoss: 0.207, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 20, TrainLoss: 0.094, TrainAcc: 97.8,TestLoss: 0.345, TestAcc: 90.7, learning_rate: 0.0001
Epoch: 21, TrainLoss: 0.217, TrainAcc: 97.0,TestLoss: 0.211, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 22, TrainLoss: 0.065, TrainAcc: 98.3,TestLoss: 0.233, TestAcc: 94.2, learning_rate: 0.0001
Epoch: 23, TrainLoss: 0.023, TrainAcc: 99.1,TestLoss: 0.256, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 24, TrainLoss: 0.115, TrainAcc: 98.4,TestLoss: 0.224, TestAcc: 94.2, learning_rate: 0.0001
Epoch: 25, TrainLoss: 0.070, TrainAcc: 97.7,TestLoss: 0.233, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 26, TrainLoss: 0.042, TrainAcc: 99.0,TestLoss: 0.239, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 27, TrainLoss: 0.033, TrainAcc: 99.0,TestLoss: 0.277, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 28, TrainLoss: 0.052, TrainAcc: 98.2,TestLoss: 0.192, TestAcc: 96.0, learning_rate: 0.0001
Epoch: 29, TrainLoss: 0.075, TrainAcc: 99.3,TestLoss: 0.288, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 30, TrainLoss: 0.178, TrainAcc: 94.3,TestLoss: 0.249, TestAcc: 91.1, learning_rate: 0.0001
Epoch: 31, TrainLoss: 0.101, TrainAcc: 97.6,TestLoss: 0.247, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 32, TrainLoss: 0.051, TrainAcc: 98.1,TestLoss: 0.226, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 33, TrainLoss: 0.095, TrainAcc: 99.0,TestLoss: 0.280, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 34, TrainLoss: 0.241, TrainAcc: 96.9,TestLoss: 0.209, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 35, TrainLoss: 0.076, TrainAcc: 98.2,TestLoss: 0.175, TestAcc: 94.2, learning_rate: 0.0001
Epoch: 36, TrainLoss: 0.032, TrainAcc: 99.1,TestLoss: 0.222, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 37, TrainLoss: 0.016, TrainAcc: 99.3,TestLoss: 0.267, TestAcc: 92.4, learning_rate: 0.0001
Epoch: 38, TrainLoss: 0.041, TrainAcc: 98.7,TestLoss: 0.226, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 39, TrainLoss: 0.035, TrainAcc: 99.0,TestLoss: 0.173, TestAcc: 95.6, learning_rate: 0.0001
Epoch: 40, TrainLoss: 0.014, TrainAcc: 99.3,TestLoss: 0.229, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 41, TrainLoss: 0.086, TrainAcc: 98.7,TestLoss: 0.197, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 42, TrainLoss: 0.097, TrainAcc: 96.7,TestLoss: 0.451, TestAcc: 91.1, learning_rate: 0.0001
Epoch: 43, TrainLoss: 0.035, TrainAcc: 98.7,TestLoss: 0.225, TestAcc: 92.9, learning_rate: 0.0001
Epoch: 44, TrainLoss: 0.043, TrainAcc: 99.4,TestLoss: 0.248, TestAcc: 92.4, learning_rate: 0.0001
Epoch: 45, TrainLoss: 0.235, TrainAcc: 97.2,TestLoss: 0.277, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 46, TrainLoss: 0.123, TrainAcc: 96.2,TestLoss: 0.180, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 47, TrainLoss: 0.033, TrainAcc: 98.9,TestLoss: 0.143, TestAcc: 97.3, learning_rate: 0.0001
Epoch: 48, TrainLoss: 0.189, TrainAcc: 99.3,TestLoss: 0.188, TestAcc: 95.6, learning_rate: 0.0001
Epoch: 49, TrainLoss: 0.041, TrainAcc: 98.9,TestLoss: 0.257, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 50, TrainLoss: 0.036, TrainAcc: 99.0,TestLoss: 0.263, TestAcc: 93.8, learning_rate: 0.0001
training finished, save best model to : best_p9_model.pth)
done
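
In the log above, the highest test accuracy is 97.3% at epoch 47, and that is the model the loop keeps. The sketch below isolates the selection rule from the training loop: a strict `<` comparison, so a later epoch that only ties the best score does not replace the earlier one. The function name is made up for this illustration.

```python
def best_epoch(accuracies):
    """Return (1-based position, accuracy) of the first strict maximum."""
    best_acc, best_idx = 0, None
    for idx, acc in enumerate(accuracies):
        if best_acc < acc:  # strict comparison: ties keep the earlier entry
            best_acc, best_idx = acc, idx
    return (best_idx + 1 if best_idx is not None else None), best_acc


# Test accuracies (percent) for epochs 11-13 and 26-28, copied from the log above.
subset = [93.3, 91.6, 96.0, 93.8, 93.3, 96.0]
print(best_epoch(subset))
```

In this subset the two 96.0 entries correspond to epochs 13 and 28, and the rule keeps the earlier one. Because selection uses the test set, the best score is an optimistic estimate; a separate validation split would give a cleaner number.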

Visualization:

import matplotlib.pyplot as plt
# Suppress warnings
import warnings

warnings.filterwarnings("ignore")  # Ignore warning messages
plt.rcParams['font.sans-serif'] = ['SimHei']  # Font that can render Chinese labels
plt.rcParams['axes.unicode_minus'] = False  # Render the minus sign correctly
plt.rcParams['figure.dpi'] = 100  # Figure resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()
model.load_state_dict(torch.load(best_model_path))
model.to(device)
