- 🍨 This post is a study log from the 🔗365-Day Deep Learning Training Camp (365天深度学习训练营)
- 🍖 Original author: K同学啊
- 🚀 Source: K同学的学习圈子
Environment:
Python version: 3.8.17 (default, Jul 5 2023, 20:44:21) [MSC v.1916 64 bit (AMD64)]
Pytorch version: 2.0.1+cu117
Torchvision version: 0.15.2+cu117
CUDA is available: True
Using device: cuda
The weather dataset used here was provided by K同学; contact K同学 if you need a copy.
I. Preparation
1. Import the required packages and set up the GPU
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pathlib, random, copy
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchinfo import summary

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
```
Output:
cuda
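2. Data preprocessing and visualization
Put the downloaded dataset into one folder and set its path. The code below reads all image paths, prints the path and shape of five random images, and displays twenty random images with their class names.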
```python
image_path = 'weather_photos'
# List of all image files (one subfolder per class)
image_list = list(pathlib.Path(image_path).glob('*/*'))

# Print the path and shape (height, width, channels) of a few random images
for _ in range(5):
    image = random.choice(image_list)
    print(f"{str(image)}, shape is: {np.array(Image.open(str(image))).shape}")

# Show 20 random images, titled with their class (the parent folder name)
plt.figure(figsize=(20, 4))
for i in range(20):
    plt.subplot(2, 10, i + 1)
    plt.axis('off')
    image = random.choice(image_list)
    plt.title(image.parts[-2])
    plt.imshow(Image.open(str(image)))
plt.show()
```
Output:
weather_photos\cloudy\cloudy167.jpg, shape is: (167, 222, 3)
weather_photos\cloudy\cloudy290.jpg, shape is: (185, 298, 3)
weather_photos\sunrise\sunrise324.jpg, shape is: (168, 252, 3)
weather_photos\sunrise\sunrise67.jpg, shape is: (194, 259, 3)
weather_photos\rain\rain142.jpg, shape is: (420, 640, 3)
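Preprocess the images (resize to 224×224, convert to tensors, normalize with the ImageNet mean and standard deviation) and print the class names:
```python
img_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# ImageFolder assigns one label per subfolder, in sorted folder-name order
dataset = datasets.ImageFolder(image_path, transform=img_transform)
class_names = [k for k in dataset.class_to_idx]
print(class_names)
```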
Class names:
['cloudy', 'rain', 'shine', 'sunrise']
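To make the preprocessing concrete, the sketch below applies the same arithmetic as `ToTensor` followed by `Normalize` to a single 8-bit RGB pixel, using the mean and std values from the code above. It is a plain-Python illustration only (the helper name `normalize_pixel` is hypothetical), not how torchvision implements it internally.

```python
# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Mimic ToTensor + Normalize for one 8-bit RGB pixel (hypothetical helper)."""
    # ToTensor scales uint8 values from [0, 255] to [0.0, 1.0]
    scaled = [v / 255.0 for v in rgb]
    # Normalize then computes (x - mean) / std for each channel
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

After normalization, inputs lie roughly between -2.1 and 2.6 instead of 0 to 255, which is the range the network actually sees.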
Split the dataset 80/20 into training and test sets, set the batch size, and build the training and test DataLoaders.
```python
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
```
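Conceptually, `random_split` shuffles the sample indices and slices them into subsets of the requested lengths. The pure-Python sketch below (the function `split_indices` and the sample count of 1000 are hypothetical) shows that idea and why a fixed seed makes the split repeatable. In PyTorch itself, the same effect can be had by passing a seeded `torch.Generator` through `random_split`'s `generator` argument; without it, each run produces a different test set, so test accuracies are not directly comparable across runs.

```python
import random


def split_indices(n, train_fraction=0.8, seed=None):
    """Pure-Python sketch of a random train/test split over n sample indices."""
    train_size = int(n * train_fraction)   # same rounding as the code above
    test_size = n - train_size             # the remainder goes to the test set
    indices = list(range(n))
    random.Random(seed).shuffle(indices)   # a fixed seed gives the same permutation every run
    return indices[:train_size], indices[train_size:train_size + test_size]


train_idx, test_idx = split_indices(1000, seed=42)
print(len(train_idx), len(test_idx))  # 800 200
```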
II. Model construction and visualization
1. Background
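YOLOv5 is a lightweight object-detection model. Its backbone is the network that extracts features from the image. YOLOv5's backbone is a CSP variant of Darknet53 (usually called CSPDarknet53): CSP stands for Cross Stage Partial network, and Darknet53 is the 53-layer convolutional backbone introduced with YOLOv3.
CSPDarknet53 is a key component of YOLOv5 and improves on Darknet53 by adding CSP blocks; YOLOv5 also ends the backbone with an SPP block (SPPF in recent versions). Together these improve both accuracy and efficiency.
The CSP structure splits the input feature map into two branches: one branch passes through a stack of residual blocks (Bottlenecks), while the other bypasses them through a cheap 1×1 convolution. The two branches are then concatenated and fused by another 1×1 convolution. Because only part of the channels go through the expensive blocks, computation and memory use drop while the network keeps its representational power for detection.
SPP stands for Spatial Pyramid Pooling. It max-pools the same feature map with several kernel sizes and concatenates the results, so features are aggregated at multiple scales. This enlarges the effective receptive field and helps the network detect objects of different sizes.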
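The `SPPF` class defined later in this post does not run pools of size 5, 9 and 13 in parallel. Instead, it chains three stride-1 5×5 max-pools. A stride-1 pool applied to the output of another stride-1 pool covers a wider window (5 → 9 → 13), so the chained version reproduces SPP's outputs while doing less work. The sketch below checks this in one dimension for brevity; the 2D case works the same way along each axis. `max_pool_1d` is a hypothetical helper that pads with -inf, like `nn.MaxPool2d(k, stride=1, padding=k // 2)`.

```python
NEG_INF = float("-inf")


def max_pool_1d(xs, k):
    """Stride-1 max pooling with k // 2 padding on each side (padding never wins the max)."""
    r = k // 2
    padded = [NEG_INF] * r + list(xs) + [NEG_INF] * r
    return [max(padded[i:i + k]) for i in range(len(xs))]


signal = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4]

# SPPF style: one 5-wide pool applied three times in a row
y1 = max_pool_1d(signal, 5)
y2 = max_pool_1d(y1, 5)
y3 = max_pool_1d(y2, 5)

# SPP style: independent pools with kernel sizes 5, 9 and 13
print(y1 == max_pool_1d(signal, 5))   # True
print(y2 == max_pool_1d(signal, 9))   # True
print(y3 == max_pool_1d(signal, 13))  # True
```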
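The YOLOv5 backbone can be summarized in the following steps:
1. The input image passes through a series of convolutions that extract low-level features.
2. A CSP block splits the features into two branches: one goes through residual Bottlenecks, the other bypasses them through a 1×1 convolution, and the two results are concatenated.
3. Alternating stride-2 convolutions and CSP blocks progressively downsample the feature maps and extract higher-level features.
4. Finally, an SPP block pools the features at several scales and fuses them.
5. In YOLOv5 the resulting feature maps go on to the neck and the detection head; in this post, which does image classification, they are flattened and fed into fully connected layers instead.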
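Applied to the `Network` class defined later in this post (224×224 input, five stride-2 3×3 convolutions), these steps shrink the feature map as the sketch below computes. It uses the standard convolution output-size formula. The C3 and SPPF blocks use stride 1 with same padding, so they do not change the spatial size. The result explains why the first `nn.Linear` layer expects 65536 input features.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output length of a convolution: floor((size + 2 * padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224
# conv1 in this post's Network uses padding=2; conv2..conv5 use the automatic padding 3 // 2 = 1
paddings = [2, 1, 1, 1, 1]
sizes = []
for p in paddings:
    size = conv_out(size, padding=p)
    sizes.append(size)

print(sizes)               # [113, 57, 29, 15, 8]
flat = 1024 * size * size  # SPPF output channels * height * width
print(flat)                # 65536
```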
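The architecture diagram provided by K同学 is used here to aid understanding.
YOLOv5 consists of the following main components:
Backbone: extracts image features. YOLOv5's default backbone is CSPDarknet53, a deep convolutional network with strong feature-extraction ability.
Neck: YOLOv5 uses a PANet-style feature pyramid to fuse feature maps of different scales. PANet adds a bottom-up path on top of the top-down FPN path, so both high-level semantic information and low-level localization information reach every scale, which helps detect objects of different sizes.
Head: the detection head applies a 1×1 convolution to each output scale to produce, for every anchor, the box coordinates, an objectness (confidence) score, and the class scores.
Loss Function: YOLOv5 is trained with a multi-task loss made of a bounding-box regression loss (CIoU), an objectness loss, and a classification loss; the latter two use binary cross-entropy.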
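The box-position term of the loss can be illustrated with CIoU (Complete IoU), which the Ultralytics YOLOv5 reference code uses for box regression. The pure-Python sketch below follows the published CIoU formula for a single pair of boxes in (x1, y1, x2, y2) format. CIoU combines three things: the IoU, a penalty on the distance between box centres normalized by the enclosing box's diagonal, and an aspect-ratio consistency term. It is an illustration, not the Ultralytics implementation, and `ciou_loss` is a hypothetical name.

```python
import math


def ciou_loss(box, gt):
    """1 - CIoU for two boxes given as (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = box
    gx1, gy1, gx2, gy2 = gt
    w, h = x2 - x1, y2 - y1
    gw, gh = gx2 - gx1, gy2 - gy1
    # Intersection over union
    iw = max(0.0, min(x2, gx2) - max(x1, gx1))
    ih = max(0.0, min(y2, gy2) - max(y1, gy1))
    inter = iw * ih
    iou = inter / (w * h + gw * gh - inter)
    # Squared distance between the two box centres
    rho2 = ((x1 + x2) - (gx1 + gx2)) ** 2 / 4 + ((y1 + y2) - (gy1 + gy2)) ** 2 / 4
    # Squared diagonal of the smallest box enclosing both
    c2 = (max(x2, gx2) - min(x1, gx1)) ** 2 + (max(y2, gy2) - min(y1, gy1)) ** 2
    # Aspect-ratio consistency term and its trade-off weight
    v = (4 / math.pi ** 2) * (math.atan(gw / gh) - math.atan(w / h)) ** 2
    alpha = v / ((1 - iou) + v) if v > 0 else 0.0
    return 1 - (iou - rho2 / c2 - alpha * v)


print(round(ciou_loss((0, 0, 2, 2), (1, 1, 3, 3)), 4))  # 0.9683
```

Unlike a plain 1 - IoU loss, the centre-distance term still gives a useful signal when the boxes do not overlap at all (the loss then exceeds 1).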
2. Code and model summary
```python
def pad(kernel_size, padding=None):
    # Default to 'same' padding (k // 2) when no padding is given explicitly
    if padding is None:
        padding = kernel_size // 2 if isinstance(kernel_size, int) else [k // 2 for k in kernel_size]
    return padding


class Conv(nn.Module):
    # Standard convolution block: Conv2d -> BatchNorm2d -> SiLU
    def __init__(self, ch_in, ch_out, kernel_size, stride=1, padding=None, groups=1, activation=True):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, pad(kernel_size, padding), groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(ch_out)
        self.act = nn.SiLU() if activation is True else (
            activation if isinstance(activation, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class Bottleneck(nn.Module):
    # 1x1 conv reduces the channels, 3x3 conv restores them, plus an optional residual shortcut
    def __init__(self, ch_in, ch_out, shortcut=True, groups=1, factor=0.5):
        super().__init__()
        hidden_size = int(ch_out * factor)
        self.conv1 = Conv(ch_in, hidden_size, 1)
        self.conv2 = Conv(hidden_size, ch_out, 3, groups=groups)
        self.add = shortcut and ch_in == ch_out

    def forward(self, x):
        y = self.conv2(self.conv1(x))
        return x + y if self.add else y


class C3(nn.Module):
    # CSP block: one branch goes through n Bottlenecks, the other bypasses them; both are concatenated and fused
    def __init__(self, ch_in, ch_out, n=1, shortcut=True, groups=1, factor=0.5):
        super().__init__()
        hidden_size = int(ch_out * factor)
        self.conv1 = Conv(ch_in, hidden_size, 1)
        self.conv2 = Conv(ch_in, hidden_size, 1)
        self.conv3 = Conv(2 * hidden_size, ch_out, 1)
        self.m = nn.Sequential(*(Bottleneck(hidden_size, hidden_size, shortcut, groups) for _ in range(n)))

    def forward(self, x):
        return self.conv3(torch.cat((self.conv1(x), self.m(self.conv2(x))), dim=1))


class SPPF(nn.Module):
    # Fast SPP: three chained max-pools, equivalent to parallel pools of size 5, 9 and 13
    def __init__(self, ch_in, ch_out, kernel_size=5):
        super().__init__()
        hidden_size = ch_in // 2
        self.conv1 = Conv(ch_in, hidden_size, 1)
        self.conv2 = Conv(4 * hidden_size, ch_out, 1)
        self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        x = self.conv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        y3 = self.m(y2)
        return self.conv2(torch.cat([x, y1, y2, y3], dim=1))


class Network(nn.Module):
    # YOLOv5-style backbone (Conv + C3 stages, then SPPF) followed by a fully connected classifier
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = Conv(3, 64, 3, 2, 2)    # 224 -> 113 (explicit padding=2)
        self.conv2 = Conv(64, 128, 3, 2)     # 113 -> 57
        self.c3_1 = C3(128, 128)
        self.conv3 = Conv(128, 256, 3, 2)    # 57 -> 29
        self.c3_2 = C3(256, 256)
        self.conv4 = Conv(256, 512, 3, 2)    # 29 -> 15
        self.c3_3 = C3(512, 512)
        self.conv5 = Conv(512, 1024, 3, 2)   # 15 -> 8
        self.c3_4 = C3(1024, 1024)
        self.sppf = SPPF(1024, 1024, 5)
        self.classifier = nn.Sequential(
            nn.Linear(65536, 100),           # 1024 channels * 8 * 8 flattened features
            nn.ReLU(),
            nn.Linear(100, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.c3_1(x)
        x = self.conv3(x)
        x = self.c3_2(x)
        x = self.conv4(x)
        x = self.c3_3(x)
        x = self.conv5(x)
        x = self.c3_4(x)
        x = self.sppf(x)
        x = x.view(x.size(0), -1)            # flatten to (batch, 65536)
        return self.classifier(x)
```
Print the network architecture:
```python
model = Network(len(class_names)).to(device)
print(model)
summary(model, input_size=(32, 3, 224, 224))
```
Network(
(conv1): Conv(
(conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_1): C3(
(conv1): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(conv3): Conv(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_2): C3(
(conv1): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(conv4): Conv(
(conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_3): C3(
(conv1): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(conv5): Conv(
(conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(c3_4): C3(
(conv1): Conv(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv3): Conv(
(conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): Sequential(
(0): Bottleneck(
(conv1): Conv(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
)
)
)
(sppf): SPPF(
(conv1): Conv(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(conv2): Conv(
(conv): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act): SiLU()
)
(m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=65536, out_features=100, bias=True)
(1): ReLU()
(2): Linear(in_features=100, out_features=4, bias=True)
)
)
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
Network [32, 4] –
├─Conv: 1-1 [32, 64, 113, 113] –
│ └─Conv2d: 2-1 [32, 64, 113, 113] 1,728
│ └─BatchNorm2d: 2-2 [32, 64, 113, 113] 128
│ └─SiLU: 2-3 [32, 64, 113, 113] –
├─Conv: 1-2 [32, 128, 57, 57] –
│ └─Conv2d: 2-4 [32, 128, 57, 57] 73,728
│ └─BatchNorm2d: 2-5 [32, 128, 57, 57] 256
│ └─SiLU: 2-6 [32, 128, 57, 57] –
├─C3: 1-3 [32, 128, 57, 57] –
│ └─Conv: 2-7 [32, 64, 57, 57] –
│ │ └─Conv2d: 3-1 [32, 64, 57, 57] 8,192
│ │ └─BatchNorm2d: 3-2 [32, 64, 57, 57] 128
│ │ └─SiLU: 3-3 [32, 64, 57, 57] –
│ └─Conv: 2-8 [32, 64, 57, 57] –
│ │ └─Conv2d: 3-4 [32, 64, 57, 57] 8,192
│ │ └─BatchNorm2d: 3-5 [32, 64, 57, 57] 128
│ │ └─SiLU: 3-6 [32, 64, 57, 57] –
│ └─Sequential: 2-9 [32, 64, 57, 57] –
│ │ └─Bottleneck: 3-7 [32, 64, 57, 57] 20,672
│ └─Conv: 2-10 [32, 128, 57, 57] –
│ │ └─Conv2d: 3-8 [32, 128, 57, 57] 16,384
│ │ └─BatchNorm2d: 3-9 [32, 128, 57, 57] 256
│ │ └─SiLU: 3-10 [32, 128, 57, 57] –
├─Conv: 1-4 [32, 256, 29, 29] –
│ └─Conv2d: 2-11 [32, 256, 29, 29] 294,912
│ └─BatchNorm2d: 2-12 [32, 256, 29, 29] 512
│ └─SiLU: 2-13 [32, 256, 29, 29] –
├─C3: 1-5 [32, 256, 29, 29] –
│ └─Conv: 2-14 [32, 128, 29, 29] –
│ │ └─Conv2d: 3-11 [32, 128, 29, 29] 32,768
│ │ └─BatchNorm2d: 3-12 [32, 128, 29, 29] 256
│ │ └─SiLU: 3-13 [32, 128, 29, 29] –
│ └─Conv: 2-15 [32, 128, 29, 29] –
│ │ └─Conv2d: 3-14 [32, 128, 29, 29] 32,768
│ │ └─BatchNorm2d: 3-15 [32, 128, 29, 29] 256
│ │ └─SiLU: 3-16 [32, 128, 29, 29] –
│ └─Sequential: 2-16 [32, 128, 29, 29] –
│ │ └─Bottleneck: 3-17 [32, 128, 29, 29] 82,304
│ └─Conv: 2-17 [32, 256, 29, 29] –
│ │ └─Conv2d: 3-18 [32, 256, 29, 29] 65,536
│ │ └─BatchNorm2d: 3-19 [32, 256, 29, 29] 512
│ │ └─SiLU: 3-20 [32, 256, 29, 29] –
├─Conv: 1-6 [32, 512, 15, 15] –
│ └─Conv2d: 2-18 [32, 512, 15, 15] 1,179,648
│ └─BatchNorm2d: 2-19 [32, 512, 15, 15] 1,024
│ └─SiLU: 2-20 [32, 512, 15, 15] –
├─C3: 1-7 [32, 512, 15, 15] –
│ └─Conv: 2-21 [32, 256, 15, 15] –
│ │ └─Conv2d: 3-21 [32, 256, 15, 15] 131,072
│ │ └─BatchNorm2d: 3-22 [32, 256, 15, 15] 512
│ │ └─SiLU: 3-23 [32, 256, 15, 15] –
│ └─Conv: 2-22 [32, 256, 15, 15] –
│ │ └─Conv2d: 3-24 [32, 256, 15, 15] 131,072
│ │ └─BatchNorm2d: 3-25 [32, 256, 15, 15] 512
│ │ └─SiLU: 3-26 [32, 256, 15, 15] –
│ └─Sequential: 2-23 [32, 256, 15, 15] –
│ │ └─Bottleneck: 3-27 [32, 256, 15, 15] 328,448
│ └─Conv: 2-24 [32, 512, 15, 15] –
│ │ └─Conv2d: 3-28 [32, 512, 15, 15] 262,144
│ │ └─BatchNorm2d: 3-29 [32, 512, 15, 15] 1,024
│ │ └─SiLU: 3-30 [32, 512, 15, 15] –
├─Conv: 1-8 [32, 1024, 8, 8] –
│ └─Conv2d: 2-25 [32, 1024, 8, 8] 4,718,592
│ └─BatchNorm2d: 2-26 [32, 1024, 8, 8] 2,048
│ └─SiLU: 2-27 [32, 1024, 8, 8] –
├─C3: 1-9 [32, 1024, 8, 8] –
│ └─Conv: 2-28 [32, 512, 8, 8] –
│ │ └─Conv2d: 3-31 [32, 512, 8, 8] 524,288
│ │ └─BatchNorm2d: 3-32 [32, 512, 8, 8] 1,024
│ │ └─SiLU: 3-33 [32, 512, 8, 8] –
│ └─Conv: 2-29 [32, 512, 8, 8] –
│ │ └─Conv2d: 3-34 [32, 512, 8, 8] 524,288
│ │ └─BatchNorm2d: 3-35 [32, 512, 8, 8] 1,024
│ │ └─SiLU: 3-36 [32, 512, 8, 8] –
│ └─Sequential: 2-30 [32, 512, 8, 8] –
│ │ └─Bottleneck: 3-37 [32, 512, 8, 8] 1,312,256
│ └─Conv: 2-31 [32, 1024, 8, 8] –
│ │ └─Conv2d: 3-38 [32, 1024, 8, 8] 1,048,576
│ │ └─BatchNorm2d: 3-39 [32, 1024, 8, 8] 2,048
│ │ └─SiLU: 3-40 [32, 1024, 8, 8] –
├─SPPF: 1-10 [32, 1024, 8, 8] –
│ └─Conv: 2-32 [32, 512, 8, 8] –
│ │ └─Conv2d: 3-41 [32, 512, 8, 8] 524,288
│ │ └─BatchNorm2d: 3-42 [32, 512, 8, 8] 1,024
│ │ └─SiLU: 3-43 [32, 512, 8, 8] –
│ └─MaxPool2d: 2-33 [32, 512, 8, 8] –
│ └─MaxPool2d: 2-34 [32, 512, 8, 8] –
│ └─MaxPool2d: 2-35 [32, 512, 8, 8] –
│ └─Conv: 2-36 [32, 1024, 8, 8] –
│ │ └─Conv2d: 3-44 [32, 1024, 8, 8] 2,097,152
│ │ └─BatchNorm2d: 3-45 [32, 1024, 8, 8] 2,048
│ │ └─SiLU: 3-46 [32, 1024, 8, 8] –
├─Sequential: 1-11 [32, 4] –
│ └─Linear: 2-37 [32, 100] 6,553,700
│ └─ReLU: 2-38 [32, 100] –
│ └─Linear: 2-39 [32, 4] 404
===============================================================================================
Total params: 19,987,832
Trainable params: 19,987,832
Non-trainable params: 0
Total mult-adds (G): 64.43
===============================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 2027.63
Params size (MB): 79.95
Estimated Total Size (MB): 2126.85
===============================================================================================
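The parameter counts in the summary follow from simple formulas. A bias-free `Conv2d` has c_in·c_out·k² weights, `BatchNorm2d` has 2·c_out learnable parameters (weight and bias; the running statistics are buffers), and `Linear` has n_in·n_out + n_out. The sketch below rebuilds the total from the block definitions above. The helper names are hypothetical.

```python
def conv_bn_params(c_in, c_out, k):
    """This post's Conv block: bias-free Conv2d plus BatchNorm2d weight and bias."""
    return c_in * c_out * k * k + 2 * c_out


def bottleneck_params(c):
    """Bottleneck(c, c) with factor 0.5: 1x1 conv down to c // 2, then 3x3 conv back to c."""
    h = c // 2
    return conv_bn_params(c, h, 1) + conv_bn_params(h, c, 3)


def c3_params(c):
    """C3(c, c), n=1: two 1x1 branch convs, one Bottleneck at hidden width, one 1x1 fusion conv."""
    h = c // 2
    return 2 * conv_bn_params(c, h, 1) + bottleneck_params(h) + conv_bn_params(2 * h, c, 1)


def sppf_params(c_in, c_out):
    h = c_in // 2
    return conv_bn_params(c_in, h, 1) + conv_bn_params(4 * h, c_out, 1)


def linear_params(n_in, n_out):
    return n_in * n_out + n_out


total = (
    conv_bn_params(3, 64, 3) + conv_bn_params(64, 128, 3) + c3_params(128)
    + conv_bn_params(128, 256, 3) + c3_params(256)
    + conv_bn_params(256, 512, 3) + c3_params(512)
    + conv_bn_params(512, 1024, 3) + c3_params(1024)
    + sppf_params(1024, 1024)
    + linear_params(65536, 100) + linear_params(100, 4)
)
print(f"{total:,}")  # 19,987,832
```

The first `Linear` layer alone contributes 6,553,700 parameters, roughly a third of the model, because it connects the full 1024×8×8 feature map to the hidden layer.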
III. Training and test functions
1. Training function
```python
def train(train_loader, model, loss_fn, optimizer):
    model.train()
    train_loss, train_acc = 0, 0
    num_batches = len(train_loader)       # number of batches
    size = len(train_loader.dataset)      # number of samples
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # Forward pass and loss
        pred = model(x)
        loss = loss_fn(pred, y)
        # Backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the batch loss and the number of correct predictions
        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_loss /= num_batches
    train_acc /= size
    return train_loss, train_acc
```
2. Test function
```python
def test(test_loader, model, loss_fn):
    model.eval()
    test_loss, test_acc = 0, 0
    num_batches = len(test_loader)
    size = len(test_loader.dataset)
    # No gradients are needed for evaluation
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            test_loss += loss.item()
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    test_acc /= size
    return test_loss, test_acc
```
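Note that the two metrics in `train` and `test` are averaged differently. `nn.CrossEntropyLoss` (default `reduction='mean'`) returns one mean loss per batch, and those batch means are averaged over the number of batches. Accuracy, by contrast, counts correct samples and divides by the dataset size. When the last batch is smaller than 32, its samples weigh more in the reported loss than in the accuracy. The sketch below uses hypothetical batch numbers to show the difference against a per-sample average loss.

```python
def epoch_metrics(batches):
    """batches: list of (mean_loss, num_correct, batch_size) tuples, one per batch."""
    num_batches = len(batches)
    size = sum(b for _, _, b in batches)
    loss = sum(l for l, _, _ in batches) / num_batches      # mean of per-batch mean losses (as above)
    acc = sum(c for _, c, _ in batches) / size              # correct predictions over all samples (as above)
    exact_loss = sum(l * b for l, _, b in batches) / size   # per-sample mean loss, for comparison
    return loss, acc, exact_loss


# Two full batches of 32 and a final partial batch of 4 (hypothetical numbers)
loss, acc, exact = epoch_metrics([(0.20, 30, 32), (0.30, 29, 32), (1.00, 2, 4)])
print(round(loss, 3), round(acc, 3), round(exact, 3))  # 0.5 0.897 0.294
```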
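3. Loss function, learning rate and optimizer
Train for 50 epochs with the cross-entropy loss and the Adam optimizer (learning rate 1e-4), and set the path where the best weights will be saved.
```python
epochs = 50
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
best_acc = 0
best_model_path = 'best_p9_model.pth'
# Per-epoch history for plotting
train_loss, train_acc = [], []
test_loss, test_acc = [], []
```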
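`nn.CrossEntropyLoss` takes the raw logits (the classifier above ends with a `Linear` layer and applies no softmax) together with integer class labels. It combines log-softmax and negative log-likelihood. The single-sample sketch below (the helper `cross_entropy` is hypothetical) shows the computation. It also gives a useful reference point: with four classes, predicting all classes equally yields a loss of ln 4 ≈ 1.386.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Four logits, one per class (cloudy, rain, shine, sunrise); the true class is index 2
uniform = cross_entropy([0.0, 0.0, 0.0, 0.0], target=2)
confident = cross_entropy([0.5, -1.0, 4.0, 0.0], target=2)
print(round(uniform, 4))  # 1.3863
print(confident < uniform)  # True
```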
IV. Training and visualization
```python
for epoch in range(epochs):
    epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)
    # Keep a copy of the model with the best test accuracy so far
    if best_acc < epoch_test_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    # Current learning rate (constant here, since no scheduler is used)
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    print(
        f"Epoch: {epoch + 1}, TrainLoss: {epoch_train_loss:.3f}, TrainAcc: {epoch_train_acc * 100:.1f},TestLoss: {epoch_test_loss:.3f}, TestAcc: {epoch_test_acc * 100:.1f}, learning_rate: {lr}")

print(f"training finished, save best model to : {best_model_path})")
torch.save(best_model.state_dict(), best_model_path)
print("done")
```
The loss and accuracy of every epoch are printed, and the best weights are saved:
Epoch: 1, TrainLoss: 0.854, TrainAcc: 63.8,TestLoss: 2.602, TestAcc: 27.1, learning_rate: 0.0001
Epoch: 2, TrainLoss: 0.783, TrainAcc: 73.1,TestLoss: 0.471, TestAcc: 75.6, learning_rate: 0.0001
Epoch: 3, TrainLoss: 0.467, TrainAcc: 80.4,TestLoss: 0.363, TestAcc: 84.4, learning_rate: 0.0001
Epoch: 4, TrainLoss: 0.431, TrainAcc: 89.3,TestLoss: 0.466, TestAcc: 84.4, learning_rate: 0.0001
Epoch: 5, TrainLoss: 0.328, TrainAcc: 89.9,TestLoss: 0.343, TestAcc: 89.3, learning_rate: 0.0001
Epoch: 6, TrainLoss: 0.382, TrainAcc: 87.2,TestLoss: 0.260, TestAcc: 87.1, learning_rate: 0.0001
Epoch: 7, TrainLoss: 0.258, TrainAcc: 90.9,TestLoss: 0.299, TestAcc: 88.4, learning_rate: 0.0001
Epoch: 8, TrainLoss: 0.163, TrainAcc: 93.9,TestLoss: 0.409, TestAcc: 87.1, learning_rate: 0.0001
Epoch: 9, TrainLoss: 0.174, TrainAcc: 93.0,TestLoss: 0.309, TestAcc: 88.0, learning_rate: 0.0001
Epoch: 10, TrainLoss: 0.283, TrainAcc: 92.3,TestLoss: 0.306, TestAcc: 88.4, learning_rate: 0.0001
Epoch: 11, TrainLoss: 0.246, TrainAcc: 95.0,TestLoss: 0.162, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 12, TrainLoss: 0.172, TrainAcc: 93.9,TestLoss: 0.241, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 13, TrainLoss: 0.114, TrainAcc: 95.8,TestLoss: 0.148, TestAcc: 96.0, learning_rate: 0.0001
Epoch: 14, TrainLoss: 0.082, TrainAcc: 97.4,TestLoss: 0.219, TestAcc: 92.0, learning_rate: 0.0001
Epoch: 15, TrainLoss: 0.111, TrainAcc: 96.4,TestLoss: 0.238, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 16, TrainLoss: 0.130, TrainAcc: 97.9,TestLoss: 0.197, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 17, TrainLoss: 0.206, TrainAcc: 97.4,TestLoss: 0.326, TestAcc: 90.2, learning_rate: 0.0001
Epoch: 18, TrainLoss: 0.101, TrainAcc: 96.7,TestLoss: 0.147, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 19, TrainLoss: 0.080, TrainAcc: 98.4,TestLoss: 0.207, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 20, TrainLoss: 0.094, TrainAcc: 97.8,TestLoss: 0.345, TestAcc: 90.7, learning_rate: 0.0001
Epoch: 21, TrainLoss: 0.217, TrainAcc: 97.0,TestLoss: 0.211, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 22, TrainLoss: 0.065, TrainAcc: 98.3,TestLoss: 0.233, TestAcc: 94.2, learning_rate: 0.0001
Epoch: 23, TrainLoss: 0.023, TrainAcc: 99.1,TestLoss: 0.256, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 24, TrainLoss: 0.115, TrainAcc: 98.4,TestLoss: 0.224, TestAcc: 94.2, learning_rate: 0.0001
Epoch: 25, TrainLoss: 0.070, TrainAcc: 97.7,TestLoss: 0.233, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 26, TrainLoss: 0.042, TrainAcc: 99.0,TestLoss: 0.239, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 27, TrainLoss: 0.033, TrainAcc: 99.0,TestLoss: 0.277, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 28, TrainLoss: 0.052, TrainAcc: 98.2,TestLoss: 0.192, TestAcc: 96.0, learning_rate: 0.0001
Epoch: 29, TrainLoss: 0.075, TrainAcc: 99.3,TestLoss: 0.288, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 30, TrainLoss: 0.178, TrainAcc: 94.3,TestLoss: 0.249, TestAcc: 91.1, learning_rate: 0.0001
Epoch: 31, TrainLoss: 0.101, TrainAcc: 97.6,TestLoss: 0.247, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 32, TrainLoss: 0.051, TrainAcc: 98.1,TestLoss: 0.226, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 33, TrainLoss: 0.095, TrainAcc: 99.0,TestLoss: 0.280, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 34, TrainLoss: 0.241, TrainAcc: 96.9,TestLoss: 0.209, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 35, TrainLoss: 0.076, TrainAcc: 98.2,TestLoss: 0.175, TestAcc: 94.2, learning_rate: 0.0001
Epoch: 36, TrainLoss: 0.032, TrainAcc: 99.1,TestLoss: 0.222, TestAcc: 93.3, learning_rate: 0.0001
Epoch: 37, TrainLoss: 0.016, TrainAcc: 99.3,TestLoss: 0.267, TestAcc: 92.4, learning_rate: 0.0001
Epoch: 38, TrainLoss: 0.041, TrainAcc: 98.7,TestLoss: 0.226, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 39, TrainLoss: 0.035, TrainAcc: 99.0,TestLoss: 0.173, TestAcc: 95.6, learning_rate: 0.0001
Epoch: 40, TrainLoss: 0.014, TrainAcc: 99.3,TestLoss: 0.229, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 41, TrainLoss: 0.086, TrainAcc: 98.7,TestLoss: 0.197, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 42, TrainLoss: 0.097, TrainAcc: 96.7,TestLoss: 0.451, TestAcc: 91.1, learning_rate: 0.0001
Epoch: 43, TrainLoss: 0.035, TrainAcc: 98.7,TestLoss: 0.225, TestAcc: 92.9, learning_rate: 0.0001
Epoch: 44, TrainLoss: 0.043, TrainAcc: 99.4,TestLoss: 0.248, TestAcc: 92.4, learning_rate: 0.0001
Epoch: 45, TrainLoss: 0.235, TrainAcc: 97.2,TestLoss: 0.277, TestAcc: 91.6, learning_rate: 0.0001
Epoch: 46, TrainLoss: 0.123, TrainAcc: 96.2,TestLoss: 0.180, TestAcc: 93.8, learning_rate: 0.0001
Epoch: 47, TrainLoss: 0.033, TrainAcc: 98.9,TestLoss: 0.143, TestAcc: 97.3, learning_rate: 0.0001
Epoch: 48, TrainLoss: 0.189, TrainAcc: 99.3,TestLoss: 0.188, TestAcc: 95.6, learning_rate: 0.0001
Epoch: 49, TrainLoss: 0.041, TrainAcc: 98.9,TestLoss: 0.257, TestAcc: 95.1, learning_rate: 0.0001
Epoch: 50, TrainLoss: 0.036, TrainAcc: 99.0,TestLoss: 0.263, TestAcc: 93.8, learning_rate: 0.0001
training finished, save best model to : best_p9_model.pth)
done
Visualization:
```python
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")               # hide warning messages
plt.rcParams['font.sans-serif'] = ['SimHei']    # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False      # display the minus sign correctly
plt.rcParams['figure.dpi'] = 100                # resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

# Load the best saved weights back into the model
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)
```