This post is a study-log entry for the 🔗365-Day Deep Learning Training Camp.
Original author: K同学啊 (tutoring and custom projects available)
My environment:
OS: Windows 10
Language: Python 3.9.13
Editor: Jupyter Notebook
Deep learning framework: PyTorch 2.3.0+cpu, torchvision 0.18.0+cpu
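I. Introduction
Let's explore an application of deep learning in medicine. Breast cancer is the most common form of cancer in women, and invasive ductal carcinoma (IDC) is the most common form of breast cancer. Accurately identifying and classifying breast cancer subtypes is an important clinical task, and automating it with deep learning can save time and reduce errors. Our dataset consists of multiple whole-slide images of breast cancer (BCa) specimens scanned at 40x.
II. Preparation
1. Set up the GPU
Use the GPU if the device supports one; otherwise fall back to the CPU.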
import os, PIL, pathlib, warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
warnings.filterwarnings("ignore")  # suppress warning messages
# This machine only has a CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Output:
device(type='cpu')
2. Import the data
# Import the data
data_dir='./J3-1/'
data_dir=pathlib.Path(data_dir)
data_paths=list(data_dir.glob('*'))
# path.name gives the folder name on any OS (no Windows-only '\\' split)
classNames=[path.name for path in data_paths]
classNames
Output:
['0', '1']
num_classes=len(classNames)
num_classes
Output:
2
train_transforms = transforms.Compose([
transforms.Resize([224,224]),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean = [0.485,0.456,0.406],
std = [0.229,0.224,0.225]
)
])
test_transforms = transforms.Compose([
transforms.Resize([224,224]),
transforms.ToTensor(),
transforms.Normalize(
mean = [0.485,0.456,0.406],
std = [0.229,0.224,0.225]
)
])
total_data = datasets.ImageFolder(data_dir,transform = train_transforms)
total_data
Output:
Dataset ImageFolder
    Number of datapoints: 13403
    Root location: J3-1
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
total_data.class_to_idx
Output:
{'0': 0, '1': 1}
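The Normalize step in the transforms above computes (value - mean) / std for each channel. The following is a minimal sketch of that arithmetic in plain Python; `normalize_pixel` is a hypothetical helper written for illustration, not part of torchvision.

```python
# Channel statistics copied from the transforms in this post.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (value - mean) / std per channel to one RGB pixel in [0, 1].

    Hypothetical helper: it mirrors the per-channel arithmetic of
    transforms.Normalize for a single pixel.
    """
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


# A pixel sitting exactly at the channel means maps to zero everywhere.
at_mean = normalize_pixel(MEAN)

# A pure white pixel (1.0 in every channel after ToTensor) lands between 2 and 3.
white = normalize_pixel([1.0, 1.0, 1.0])

print(at_mean)
print(white)
```

After normalization the inputs are roughly centered around zero, which is the range the ImageNet statistics were chosen for.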
3. Split the dataset
# Split the dataset
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print(train_dataset)
print(test_dataset)
batch_size = 8
train_dl = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
#num_workers=1
)
test_dl = torch.utils.data.DataLoader(test_dataset,
batch_size=batch_size,
shuffle=True,
#num_workers=1
)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
Output:
<torch.utils.data.dataset.Subset object at 0x0000024D533D34F0>
<torch.utils.data.dataset.Subset object at 0x0000024D533D3EB0>
Shape of X [N, C, H, W]:  torch.Size([8, 3, 224, 224])
Shape of y:  torch.Size([8]) torch.int64
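Note that total_data is created with train_transforms, so both splits receive the random horizontal flip, and test_transforms is defined but never used. One possible way to give each split its own transform is to split an untransformed dataset and wrap each subset. The sketch below assumes this approach; `TransformedSubset` and `ToyDataset` are hypothetical names, and the toy data and lambda transforms stand in for an ImageFolder built with transform=None and for the real transform pipelines.

```python
import torch
from torch.utils.data import Dataset, random_split


class TransformedSubset(Dataset):
    """Hypothetical wrapper that applies its own transform to a subset."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


class ToyDataset(Dataset):
    """Stand-in for an ImageFolder created with transform=None."""

    def __init__(self, n):
        self.items = [(torch.full((3, 4, 4), float(i)), i % 2) for i in range(n)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


raw = ToyDataset(10)
train_size = int(0.8 * len(raw))
test_size = len(raw) - train_size
train_raw, test_raw = random_split(
    raw, [train_size, test_size], generator=torch.Generator().manual_seed(0)
)

# The lambdas are placeholders for train_transforms and test_transforms.
train_ds = TransformedSubset(train_raw, transform=lambda x: x + 1.0)
test_ds = TransformedSubset(test_raw, transform=lambda x: x)

print(len(train_ds), len(test_ds))
```

With this layout, augmentation such as RandomHorizontalFlip only touches the training split, while evaluation sees deterministic inputs.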
III. Building the Network Model
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
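1. The DenseLayer module
self.add_module() registers a submodule on a module. In this code, self.add_module('name', module) adds each layer to the DenseLayer class. A layer can be an instance of any subclass of nn.Module, such as a batch normalization layer (nn.BatchNorm2d), a ReLU activation (nn.ReLU), or a convolution layer (nn.Conv2d).
The function takes a string name for the submodule and the module instance to add. Here the layers are added to DenseLayer in the order in which they should run.
Submodules registered with self.add_module() are stored inside the DenseLayer instance and become attributes of it, so other methods can refer to them by name. Because DenseLayer inherits from nn.Sequential, its forward() method can call super().forward(x) to run the registered layers in order. In short, self.add_module() adds a submodule and gives it an attribute name for later reference.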
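As a small illustration of add_module(), the sketch below registers two layers on an nn.Sequential subclass and then reads them back by name. `TinyBlock` is a hypothetical class made up for this example; it is not part of the DenseNet code.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Sequential):
    """Hypothetical two-layer block, registered the same way as DenseLayer."""

    def __init__(self, channels):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(channels))
        self.add_module('relu', nn.ReLU(inplace=True))


block = TinyBlock(4)

# Registered submodules keep their insertion order and their names.
names = [name for name, _ in block.named_children()]
print(names)

# Each name is also available as an attribute on the instance.
print(type(block.norm).__name__)

# nn.Sequential.forward runs the registered modules in order.
out = block(torch.randn(2, 4, 8, 8))
print(out.shape)
```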
# BN + ReLU + 1x1 Conv + BN + ReLU + 3x3 Conv; dropout is applied at the end during training
class DenseLayer(nn.Sequential):
    '''Basic unit of DenseBlock (using bottleneck layer)'''
    def __init__(self,in_channel,growth_rate,bn_size,drop_rate):
        super(DenseLayer,self).__init__()
        self.add_module('norm1',nn.BatchNorm2d(in_channel))
        self.add_module('relu1',nn.ReLU(inplace=True))
        self.add_module('conv1',nn.Conv2d(in_channel, bn_size*growth_rate,kernel_size=1,stride=1,bias=False))
        self.add_module('norm2',nn.BatchNorm2d(bn_size*growth_rate))
        self.add_module('relu2',nn.ReLU(inplace=True))
        self.add_module('conv2',nn.Conv2d(bn_size*growth_rate,growth_rate,kernel_size=3,stride=1,padding=1,bias=False))
        self.drop_rate = drop_rate
    def forward(self,x):
        # Run the registered layers in order to produce growth_rate new channels
        new_features = super(DenseLayer,self).forward(x)
        if self.drop_rate>0:
            new_features = F.dropout(new_features,p=self.drop_rate,training=self.training)
        # Concatenate the input with the new features along the channel dimension
        return torch.cat([x,new_features],1)
2. The DenseBlock module
# Dense connectivity inside the block (the number of input features grows linearly)
class DenseBlock(nn.Sequential):
    def __init__(self,num_layers,in_channel,bn_size,growth_rate,drop_rate):
        super(DenseBlock,self).__init__()
        for i in range(num_layers):
            # Layer i sees the block input plus the outputs of the i layers before it
            layer = DenseLayer(in_channel+i*growth_rate,growth_rate,bn_size,drop_rate)
            self.add_module('denselayer%d'%(i+1,),layer)
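The expression in_channel+i*growth_rate can be worked through by hand. The sketch below does so in plain Python for a block of 6 layers with 64 input channels and a growth rate of 32, the values of the first block of the model built later in this post; `dense_block_in_channels` is a hypothetical helper used only for illustration.

```python
def dense_block_in_channels(num_layers, in_channel, growth_rate):
    """Input channel count seen by each DenseLayer inside one DenseBlock."""
    return [in_channel + i * growth_rate for i in range(num_layers)]


first_block = dense_block_in_channels(num_layers=6, in_channel=64, growth_rate=32)
print(first_block)  # [64, 96, 128, 160, 192, 224]

# Every layer appends growth_rate channels, so the block output has:
block_out = 64 + 6 * 32
print(block_out)  # 256
```

These values match the BatchNorm2d sizes of denselayer1 through denselayer6 and the 256-channel input of transition1 in the printed model.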
3. The Transition module
# Transition layer: mainly one convolution layer and one pooling layer
class Transition(nn.Sequential):
    '''Transition layer between two adjacent DenseBlock'''
    def __init__(self,in_channel,out_channel):
        super(Transition,self).__init__()
        self.add_module('norm',nn.BatchNorm2d(in_channel))
        self.add_module('relu',nn.ReLU(inplace=True))
        self.add_module('conv',nn.Conv2d(in_channel,out_channel,kernel_size=1,stride=1,bias=False))
        self.add_module('pool',nn.AvgPool2d(2,stride=2))
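To see what a Transition does to a feature map, the sketch below rebuilds the same four layers as a plain nn.Sequential and pushes a dummy tensor through it. The sizes (256 to 128 channels on a 56x56 map) are those of transition1 in the model built later; the random input is purely illustrative.

```python
import torch
import torch.nn as nn

# Same layer sequence as Transition(256, 128), written out as a plain Sequential.
transition = nn.Sequential(
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 128, kernel_size=1, stride=1, bias=False),
    nn.AvgPool2d(2, stride=2),
)
transition.eval()  # use running statistics so a single sample is fine

x = torch.randn(1, 256, 56, 56)  # dummy feature map
with torch.no_grad():
    y = transition(x)

print(y.shape)  # torch.Size([1, 128, 28, 28])
```

The 1x1 convolution halves the channel count and the average pooling halves the height and width.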
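4. Build DenseNet
nn.Sequential is a model container in PyTorch that runs a series of neural network modules (layers) in the given order. Using nn.Sequential can simplify the code when building a network.
OrderedDict is an ordered dictionary in Python that preserves the insertion order of its items. In a neural network, an OrderedDict can be used to name the layers of a model and fix their order.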
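A minimal sketch of the two working together: an OrderedDict gives each layer a name, and nn.Sequential runs them in that order. The layer sizes here are small made-up values chosen only to keep the example quick.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Named layers, executed in insertion order.
stem = nn.Sequential(OrderedDict([
    ('conv0', nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)),
    ('relu0', nn.ReLU(inplace=True)),
]))

# More layers can be appended later with add_module().
stem.add_module('pool0', nn.MaxPool2d(2))

layer_names = [name for name, _ in stem.named_children()]
print(layer_names)

out = stem(torch.randn(1, 3, 16, 16))
print(out.shape)
```

Named layers also show up under those names in the printed model and in state_dict keys, which makes a large network easier to inspect.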
# DenseNet implementation
# DenseNet-BC model
class DenseNet(nn.Module):
    def __init__(self,growth_rate=32,block_config=(6,12,24,16),init_channel=64,bn_size=4,compression_rate=0.5,drop_rate=0,num_classes=1000):
        """
        :param growth_rate: (int) number of filters used in DenseLayer, 'k' in the paper
        :param block_config: (list of 4 ints) number of layers in each DenseBlock
        :param init_channel: (int) number of filters in the first Conv2d
        :param bn_size: (int) the factor used in the bottleneck layer
        :param compression_rate: (float) the compression rate used in Transition Layer
        :param drop_rate: (float) the drop rate after each DenseLayer
        :param num_classes: (int) number of classes for classification
        """
        super(DenseNet,self).__init__()
        # first Conv2d
        self.features = nn.Sequential(OrderedDict([
            ('conv0',nn.Conv2d(3,init_channel,kernel_size=7,stride=2,padding=3,bias=False)),
            ('norm0',nn.BatchNorm2d(init_channel)),
            ('relu0',nn.ReLU(inplace=True)),
            ('pool0',nn.MaxPool2d(3,stride=2,padding=1))
        ]))
        # DenseBlock
        num_features = init_channel
        for i,num_layers in enumerate(block_config):
            block = DenseBlock(num_layers,num_features,bn_size,growth_rate,drop_rate)
            self.features.add_module('denseblock%d'%(i+1),block)
            num_features += num_layers*growth_rate
            # Every block except the last is followed by a Transition
            if i != len(block_config)-1:
                transition = Transition(num_features,int(num_features*compression_rate))
                self.features.add_module('transition%d'%(i+1),transition)
                num_features = int(num_features*compression_rate)
        # final bn+ReLU
        self.features.add_module('norm5',nn.BatchNorm2d(num_features))
        self.features.add_module('relu5',nn.ReLU(inplace=True))
        # classification layer
        self.classifier = nn.Linear(num_features,num_classes)
        # params initialization
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m,nn.BatchNorm2d):
                # use the in-place constant_ (constant is deprecated)
                nn.init.constant_(m.bias,0)
                nn.init.constant_(m.weight,1)
            elif isinstance(m,nn.Linear):
                nn.init.constant_(m.bias,0)
    def forward(self,x):
        x = self.features(x)
        # 7x7 average pooling assumes a 224x224 input (7x7 final feature map)
        x = F.avg_pool2d(x,7,stride=1).view(x.size(0),-1)
        x = self.classifier(x)
        return x
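The num_features bookkeeping in the constructor and the fixed 7x7 pooling in forward() can both be checked with a little arithmetic. The sketch below replays that bookkeeping in plain Python with the default arguments; `densenet_feature_channels` is a hypothetical helper, and the spatial sizes assume a 224x224 input.

```python
def densenet_feature_channels(block_config=(6, 12, 24, 16), init_channel=64,
                              growth_rate=32, compression_rate=0.5):
    """Replay the num_features updates from DenseNet.__init__ (illustrative helper)."""
    trace = []
    num_features = init_channel
    for i, num_layers in enumerate(block_config):
        num_features += num_layers * growth_rate
        trace.append(('denseblock%d' % (i + 1), num_features))
        if i != len(block_config) - 1:
            num_features = int(num_features * compression_rate)
            trace.append(('transition%d' % (i + 1), num_features))
    return trace


trace = densenet_feature_channels()
for name, channels in trace:
    print(name, channels)

# Spatial size for a 224x224 input: conv0 (stride 2) and pool0 (stride 2)
# each halve it, and so does each of the three transitions.
size = 224
for _ in range(2 + 3):
    size //= 2
print(size)  # 7
```

The last entry is 1024 channels, which is the in_features of the classifier, and the final 7x7 map is why forward() pools with a kernel size of 7.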
5. Build densenet121
device ="cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
densenet121=DenseNet(init_channel=64,
growth_rate=32,
block_config=(6,12,24,16),
num_classes=len(classNames))
model = densenet121.to(device)
model
Output:
Using cpu device
DenseNet( (features): Sequential( (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu0): ReLU(inplace=True) (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (denseblock1): DenseBlock( (denselayer1): DenseLayer( (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): DenseLayer( (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): DenseLayer( (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): DenseLayer( (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 
3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): DenseLayer( (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): DenseLayer( (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) ) (transition1): Transition( (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (pool): AvgPool2d(kernel_size=2, stride=2, padding=0) ) (denseblock2): DenseBlock( (denselayer1): DenseLayer( (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): DenseLayer( (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): DenseLayer( (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): DenseLayer( (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): DenseLayer( (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): DenseLayer( (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer7): DenseLayer( (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): 
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer8): DenseLayer( (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer9): DenseLayer( (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer10): DenseLayer( (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer11): DenseLayer( (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer12): DenseLayer( (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) ) (transition2): Transition( (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (pool): AvgPool2d(kernel_size=2, stride=2, padding=0) ) (denseblock3): DenseBlock( (denselayer1): DenseLayer( (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): DenseLayer( (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): DenseLayer( (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): DenseLayer( (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): DenseLayer( (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): DenseLayer( (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer7): DenseLayer( (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer8): DenseLayer( (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer9): DenseLayer( (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer10): DenseLayer( (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer11): DenseLayer( (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer12): DenseLayer( (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer13): DenseLayer( (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): 
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer14): DenseLayer( (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer15): DenseLayer( (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer16): DenseLayer( (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer17): DenseLayer( (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer18): DenseLayer( (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer19): DenseLayer( (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer20): DenseLayer( (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer21): DenseLayer( (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer22): DenseLayer( (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer23): DenseLayer( (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer24): DenseLayer( (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) ) (transition3): Transition( (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (pool): AvgPool2d(kernel_size=2, stride=2, padding=0) ) (denseblock4): DenseBlock( (denselayer1): DenseLayer( (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): DenseLayer( (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): DenseLayer( (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): DenseLayer( (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): DenseLayer( (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): DenseLayer( (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer7): DenseLayer( (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): 
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer8): DenseLayer( (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer9): DenseLayer( (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer10): DenseLayer( (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer11): DenseLayer( (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer12): DenseLayer( (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer13): DenseLayer( (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer14): DenseLayer( (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer15): DenseLayer( (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer16): DenseLayer( (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False) ) ) (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu5): ReLU(inplace=True) ) (classifier): Linear(in_features=1024, out_features=2, bias=True) )
# Report the parameter count and other per-layer statistics of the model
import torchsummary as summary
summary.summary(model,(3,224,224))
Output:
Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 64, 112, 112] 9,408 BatchNorm2d-2 [-1, 64, 112, 112] 128 ReLU-3 [-1, 64, 112, 112] 0 MaxPool2d-4 [-1, 64, 56, 56] 0 BatchNorm2d-5 [-1, 64, 56, 56] 128 ReLU-6 [-1, 64, 56, 56] 0 Conv2d-7 [-1, 128, 56, 56] 8,192 BatchNorm2d-8 [-1, 128, 56, 56] 256 ReLU-9 [-1, 128, 56, 56] 0 Conv2d-10 [-1, 32, 56, 56] 36,864 BatchNorm2d-11 [-1, 96, 56, 56] 192 ReLU-12 [-1, 96, 56, 56] 0 Conv2d-13 [-1, 128, 56, 56] 12,288 BatchNorm2d-14 [-1, 128, 56, 56] 256 ReLU-15 [-1, 128, 56, 56] 0 Conv2d-16 [-1, 32, 56, 56] 36,864 BatchNorm2d-17 [-1, 128, 56, 56] 256 ReLU-18 [-1, 128, 56, 56] 0 Conv2d-19 [-1, 128, 56, 56] 16,384 BatchNorm2d-20 [-1, 128, 56, 56] 256 ReLU-21 [-1, 128, 56, 56] 0 Conv2d-22 [-1, 32, 56, 56] 36,864 BatchNorm2d-23 [-1, 160, 56, 56] 320 ReLU-24 [-1, 160, 56, 56] 0 Conv2d-25 [-1, 128, 56, 56] 20,480 BatchNorm2d-26 [-1, 128, 56, 56] 256 ReLU-27 [-1, 128, 56, 56] 0 Conv2d-28 [-1, 32, 56, 56] 36,864 BatchNorm2d-29 [-1, 192, 56, 56] 384 ReLU-30 [-1, 192, 56, 56] 0 Conv2d-31 [-1, 128, 56, 56] 24,576 BatchNorm2d-32 [-1, 128, 56, 56] 256 ReLU-33 [-1, 128, 56, 56] 0 Conv2d-34 [-1, 32, 56, 56] 36,864 BatchNorm2d-35 [-1, 224, 56, 56] 448 ReLU-36 [-1, 224, 56, 56] 0 Conv2d-37 [-1, 128, 56, 56] 28,672 BatchNorm2d-38 [-1, 128, 56, 56] 256 ReLU-39 [-1, 128, 56, 56] 0 Conv2d-40 [-1, 32, 56, 56] 36,864 BatchNorm2d-41 [-1, 256, 56, 56] 512 ReLU-42 [-1, 256, 56, 56] 0 Conv2d-43 [-1, 128, 56, 56] 32,768 AvgPool2d-44 [-1, 128, 28, 28] 0 BatchNorm2d-45 [-1, 128, 28, 28] 256 ReLU-46 [-1, 128, 28, 28] 0 Conv2d-47 [-1, 128, 28, 28] 16,384 BatchNorm2d-48 [-1, 128, 28, 28] 256 ReLU-49 [-1, 128, 28, 28] 0 Conv2d-50 [-1, 32, 28, 28] 36,864 BatchNorm2d-51 [-1, 160, 28, 28] 320 ReLU-52 [-1, 160, 28, 28] 0 Conv2d-53 [-1, 128, 28, 28] 20,480 BatchNorm2d-54 [-1, 128, 28, 28] 256 ReLU-55 [-1, 128, 28, 28] 0 Conv2d-56 [-1, 32, 28, 28] 36,864 BatchNorm2d-57 [-1, 192, 28, 28] 384 ReLU-58 [-1, 192, 
28, 28] 0 Conv2d-59 [-1, 128, 28, 28] 24,576 BatchNorm2d-60 [-1, 128, 28, 28] 256 ReLU-61 [-1, 128, 28, 28] 0 Conv2d-62 [-1, 32, 28, 28] 36,864 BatchNorm2d-63 [-1, 224, 28, 28] 448 ReLU-64 [-1, 224, 28, 28] 0 Conv2d-65 [-1, 128, 28, 28] 28,672 BatchNorm2d-66 [-1, 128, 28, 28] 256 ReLU-67 [-1, 128, 28, 28] 0 Conv2d-68 [-1, 32, 28, 28] 36,864 BatchNorm2d-69 [-1, 256, 28, 28] 512 ReLU-70 [-1, 256, 28, 28] 0 Conv2d-71 [-1, 128, 28, 28] 32,768 BatchNorm2d-72 [-1, 128, 28, 28] 256 ReLU-73 [-1, 128, 28, 28] 0 Conv2d-74 [-1, 32, 28, 28] 36,864 BatchNorm2d-75 [-1, 288, 28, 28] 576 ReLU-76 [-1, 288, 28, 28] 0 Conv2d-77 [-1, 128, 28, 28] 36,864 BatchNorm2d-78 [-1, 128, 28, 28] 256 ReLU-79 [-1, 128, 28, 28] 0 Conv2d-80 [-1, 32, 28, 28] 36,864 BatchNorm2d-81 [-1, 320, 28, 28] 640 ReLU-82 [-1, 320, 28, 28] 0 Conv2d-83 [-1, 128, 28, 28] 40,960 BatchNorm2d-84 [-1, 128, 28, 28] 256 ReLU-85 [-1, 128, 28, 28] 0 Conv2d-86 [-1, 32, 28, 28] 36,864 BatchNorm2d-87 [-1, 352, 28, 28] 704 ReLU-88 [-1, 352, 28, 28] 0 Conv2d-89 [-1, 128, 28, 28] 45,056 BatchNorm2d-90 [-1, 128, 28, 28] 256 ReLU-91 [-1, 128, 28, 28] 0 Conv2d-92 [-1, 32, 28, 28] 36,864 BatchNorm2d-93 [-1, 384, 28, 28] 768 ReLU-94 [-1, 384, 28, 28] 0 Conv2d-95 [-1, 128, 28, 28] 49,152 BatchNorm2d-96 [-1, 128, 28, 28] 256 ReLU-97 [-1, 128, 28, 28] 0 Conv2d-98 [-1, 32, 28, 28] 36,864 BatchNorm2d-99 [-1, 416, 28, 28] 832 ReLU-100 [-1, 416, 28, 28] 0 Conv2d-101 [-1, 128, 28, 28] 53,248 BatchNorm2d-102 [-1, 128, 28, 28] 256 ReLU-103 [-1, 128, 28, 28] 0 Conv2d-104 [-1, 32, 28, 28] 36,864 BatchNorm2d-105 [-1, 448, 28, 28] 896 ReLU-106 [-1, 448, 28, 28] 0 Conv2d-107 [-1, 128, 28, 28] 57,344 BatchNorm2d-108 [-1, 128, 28, 28] 256 ReLU-109 [-1, 128, 28, 28] 0 Conv2d-110 [-1, 32, 28, 28] 36,864 BatchNorm2d-111 [-1, 480, 28, 28] 960 ReLU-112 [-1, 480, 28, 28] 0 Conv2d-113 [-1, 128, 28, 28] 61,440 BatchNorm2d-114 [-1, 128, 28, 28] 256 ReLU-115 [-1, 128, 28, 28] 0 Conv2d-116 [-1, 32, 28, 28] 36,864 BatchNorm2d-117 [-1, 512, 28, 28] 1,024 
ReLU-118 [-1, 512, 28, 28] 0 Conv2d-119 [-1, 256, 28, 28] 131,072 AvgPool2d-120 [-1, 256, 14, 14] 0 BatchNorm2d-121 [-1, 256, 14, 14] 512 ReLU-122 [-1, 256, 14, 14] 0 Conv2d-123 [-1, 128, 14, 14] 32,768 BatchNorm2d-124 [-1, 128, 14, 14] 256 ReLU-125 [-1, 128, 14, 14] 0 Conv2d-126 [-1, 32, 14, 14] 36,864 BatchNorm2d-127 [-1, 288, 14, 14] 576 ReLU-128 [-1, 288, 14, 14] 0 Conv2d-129 [-1, 128, 14, 14] 36,864 BatchNorm2d-130 [-1, 128, 14, 14] 256 ReLU-131 [-1, 128, 14, 14] 0 Conv2d-132 [-1, 32, 14, 14] 36,864 BatchNorm2d-133 [-1, 320, 14, 14] 640 ReLU-134 [-1, 320, 14, 14] 0 Conv2d-135 [-1, 128, 14, 14] 40,960 BatchNorm2d-136 [-1, 128, 14, 14] 256 ReLU-137 [-1, 128, 14, 14] 0 Conv2d-138 [-1, 32, 14, 14] 36,864 BatchNorm2d-139 [-1, 352, 14, 14] 704 ReLU-140 [-1, 352, 14, 14] 0 Conv2d-141 [-1, 128, 14, 14] 45,056 BatchNorm2d-142 [-1, 128, 14, 14] 256 ReLU-143 [-1, 128, 14, 14] 0 Conv2d-144 [-1, 32, 14, 14] 36,864 BatchNorm2d-145 [-1, 384, 14, 14] 768 ReLU-146 [-1, 384, 14, 14] 0 Conv2d-147 [-1, 128, 14, 14] 49,152 BatchNorm2d-148 [-1, 128, 14, 14] 256 ReLU-149 [-1, 128, 14, 14] 0 Conv2d-150 [-1, 32, 14, 14] 36,864 BatchNorm2d-151 [-1, 416, 14, 14] 832 ReLU-152 [-1, 416, 14, 14] 0 Conv2d-153 [-1, 128, 14, 14] 53,248 BatchNorm2d-154 [-1, 128, 14, 14] 256 ReLU-155 [-1, 128, 14, 14] 0 Conv2d-156 [-1, 32, 14, 14] 36,864 BatchNorm2d-157 [-1, 448, 14, 14] 896 ReLU-158 [-1, 448, 14, 14] 0 Conv2d-159 [-1, 128, 14, 14] 57,344 BatchNorm2d-160 [-1, 128, 14, 14] 256 ReLU-161 [-1, 128, 14, 14] 0 Conv2d-162 [-1, 32, 14, 14] 36,864 BatchNorm2d-163 [-1, 480, 14, 14] 960 ReLU-164 [-1, 480, 14, 14] 0 Conv2d-165 [-1, 128, 14, 14] 61,440 BatchNorm2d-166 [-1, 128, 14, 14] 256 ReLU-167 [-1, 128, 14, 14] 0 Conv2d-168 [-1, 32, 14, 14] 36,864 BatchNorm2d-169 [-1, 512, 14, 14] 1,024 ReLU-170 [-1, 512, 14, 14] 0 Conv2d-171 [-1, 128, 14, 14] 65,536 BatchNorm2d-172 [-1, 128, 14, 14] 256 ReLU-173 [-1, 128, 14, 14] 0 Conv2d-174 [-1, 32, 14, 14] 36,864 BatchNorm2d-175 [-1, 544, 14, 14] 1,088 ReLU-176 
[-1, 544, 14, 14] 0 Conv2d-177 [-1, 128, 14, 14] 69,632 BatchNorm2d-178 [-1, 128, 14, 14] 256 ReLU-179 [-1, 128, 14, 14] 0 Conv2d-180 [-1, 32, 14, 14] 36,864 BatchNorm2d-181 [-1, 576, 14, 14] 1,152 ReLU-182 [-1, 576, 14, 14] 0 Conv2d-183 [-1, 128, 14, 14] 73,728 BatchNorm2d-184 [-1, 128, 14, 14] 256 ReLU-185 [-1, 128, 14, 14] 0 Conv2d-186 [-1, 32, 14, 14] 36,864 BatchNorm2d-187 [-1, 608, 14, 14] 1,216 ReLU-188 [-1, 608, 14, 14] 0 Conv2d-189 [-1, 128, 14, 14] 77,824 BatchNorm2d-190 [-1, 128, 14, 14] 256 ReLU-191 [-1, 128, 14, 14] 0 Conv2d-192 [-1, 32, 14, 14] 36,864 BatchNorm2d-193 [-1, 640, 14, 14] 1,280 ReLU-194 [-1, 640, 14, 14] 0 Conv2d-195 [-1, 128, 14, 14] 81,920 BatchNorm2d-196 [-1, 128, 14, 14] 256 ReLU-197 [-1, 128, 14, 14] 0 Conv2d-198 [-1, 32, 14, 14] 36,864 BatchNorm2d-199 [-1, 672, 14, 14] 1,344 ReLU-200 [-1, 672, 14, 14] 0 Conv2d-201 [-1, 128, 14, 14] 86,016 BatchNorm2d-202 [-1, 128, 14, 14] 256 ReLU-203 [-1, 128, 14, 14] 0 Conv2d-204 [-1, 32, 14, 14] 36,864 BatchNorm2d-205 [-1, 704, 14, 14] 1,408 ReLU-206 [-1, 704, 14, 14] 0 Conv2d-207 [-1, 128, 14, 14] 90,112 BatchNorm2d-208 [-1, 128, 14, 14] 256 ReLU-209 [-1, 128, 14, 14] 0 Conv2d-210 [-1, 32, 14, 14] 36,864 BatchNorm2d-211 [-1, 736, 14, 14] 1,472 ReLU-212 [-1, 736, 14, 14] 0 Conv2d-213 [-1, 128, 14, 14] 94,208 BatchNorm2d-214 [-1, 128, 14, 14] 256 ReLU-215 [-1, 128, 14, 14] 0 Conv2d-216 [-1, 32, 14, 14] 36,864 BatchNorm2d-217 [-1, 768, 14, 14] 1,536 ReLU-218 [-1, 768, 14, 14] 0 Conv2d-219 [-1, 128, 14, 14] 98,304 BatchNorm2d-220 [-1, 128, 14, 14] 256 ReLU-221 [-1, 128, 14, 14] 0 Conv2d-222 [-1, 32, 14, 14] 36,864 BatchNorm2d-223 [-1, 800, 14, 14] 1,600 ReLU-224 [-1, 800, 14, 14] 0 Conv2d-225 [-1, 128, 14, 14] 102,400 BatchNorm2d-226 [-1, 128, 14, 14] 256 ReLU-227 [-1, 128, 14, 14] 0 Conv2d-228 [-1, 32, 14, 14] 36,864 BatchNorm2d-229 [-1, 832, 14, 14] 1,664 ReLU-230 [-1, 832, 14, 14] 0 Conv2d-231 [-1, 128, 14, 14] 106,496 BatchNorm2d-232 [-1, 128, 14, 14] 256 ReLU-233 [-1, 128, 14, 14] 0 Conv2d-234 
[-1, 32, 14, 14] 36,864 BatchNorm2d-235 [-1, 864, 14, 14] 1,728 ReLU-236 [-1, 864, 14, 14] 0 Conv2d-237 [-1, 128, 14, 14] 110,592 BatchNorm2d-238 [-1, 128, 14, 14] 256 ReLU-239 [-1, 128, 14, 14] 0 Conv2d-240 [-1, 32, 14, 14] 36,864 BatchNorm2d-241 [-1, 896, 14, 14] 1,792 ReLU-242 [-1, 896, 14, 14] 0 Conv2d-243 [-1, 128, 14, 14] 114,688 BatchNorm2d-244 [-1, 128, 14, 14] 256 ReLU-245 [-1, 128, 14, 14] 0 Conv2d-246 [-1, 32, 14, 14] 36,864 BatchNorm2d-247 [-1, 928, 14, 14] 1,856 ReLU-248 [-1, 928, 14, 14] 0 Conv2d-249 [-1, 128, 14, 14] 118,784 BatchNorm2d-250 [-1, 128, 14, 14] 256 ReLU-251 [-1, 128, 14, 14] 0 Conv2d-252 [-1, 32, 14, 14] 36,864 BatchNorm2d-253 [-1, 960, 14, 14] 1,920 ReLU-254 [-1, 960, 14, 14] 0 Conv2d-255 [-1, 128, 14, 14] 122,880 BatchNorm2d-256 [-1, 128, 14, 14] 256 ReLU-257 [-1, 128, 14, 14] 0 Conv2d-258 [-1, 32, 14, 14] 36,864 BatchNorm2d-259 [-1, 992, 14, 14] 1,984 ReLU-260 [-1, 992, 14, 14] 0 Conv2d-261 [-1, 128, 14, 14] 126,976 BatchNorm2d-262 [-1, 128, 14, 14] 256 ReLU-263 [-1, 128, 14, 14] 0 Conv2d-264 [-1, 32, 14, 14] 36,864 BatchNorm2d-265 [-1, 1024, 14, 14] 2,048 ReLU-266 [-1, 1024, 14, 14] 0 Conv2d-267 [-1, 512, 14, 14] 524,288 AvgPool2d-268 [-1, 512, 7, 7] 0 BatchNorm2d-269 [-1, 512, 7, 7] 1,024 ReLU-270 [-1, 512, 7, 7] 0 Conv2d-271 [-1, 128, 7, 7] 65,536 BatchNorm2d-272 [-1, 128, 7, 7] 256 ReLU-273 [-1, 128, 7, 7] 0 Conv2d-274 [-1, 32, 7, 7] 36,864 BatchNorm2d-275 [-1, 544, 7, 7] 1,088 ReLU-276 [-1, 544, 7, 7] 0 Conv2d-277 [-1, 128, 7, 7] 69,632 BatchNorm2d-278 [-1, 128, 7, 7] 256 ReLU-279 [-1, 128, 7, 7] 0 Conv2d-280 [-1, 32, 7, 7] 36,864 BatchNorm2d-281 [-1, 576, 7, 7] 1,152 ReLU-282 [-1, 576, 7, 7] 0 Conv2d-283 [-1, 128, 7, 7] 73,728 BatchNorm2d-284 [-1, 128, 7, 7] 256 ReLU-285 [-1, 128, 7, 7] 0 Conv2d-286 [-1, 32, 7, 7] 36,864 BatchNorm2d-287 [-1, 608, 7, 7] 1,216 ReLU-288 [-1, 608, 7, 7] 0 Conv2d-289 [-1, 128, 7, 7] 77,824 BatchNorm2d-290 [-1, 128, 7, 7] 256 ReLU-291 [-1, 128, 7, 7] 0 Conv2d-292 [-1, 32, 7, 7] 36,864 BatchNorm2d-293 
[-1, 640, 7, 7] 1,280 ReLU-294 [-1, 640, 7, 7] 0 Conv2d-295 [-1, 128, 7, 7] 81,920 BatchNorm2d-296 [-1, 128, 7, 7] 256 ReLU-297 [-1, 128, 7, 7] 0 Conv2d-298 [-1, 32, 7, 7] 36,864 BatchNorm2d-299 [-1, 672, 7, 7] 1,344 ReLU-300 [-1, 672, 7, 7] 0 Conv2d-301 [-1, 128, 7, 7] 86,016 BatchNorm2d-302 [-1, 128, 7, 7] 256 ReLU-303 [-1, 128, 7, 7] 0 Conv2d-304 [-1, 32, 7, 7] 36,864 BatchNorm2d-305 [-1, 704, 7, 7] 1,408 ReLU-306 [-1, 704, 7, 7] 0 Conv2d-307 [-1, 128, 7, 7] 90,112 BatchNorm2d-308 [-1, 128, 7, 7] 256 ReLU-309 [-1, 128, 7, 7] 0 Conv2d-310 [-1, 32, 7, 7] 36,864 BatchNorm2d-311 [-1, 736, 7, 7] 1,472 ReLU-312 [-1, 736, 7, 7] 0 Conv2d-313 [-1, 128, 7, 7] 94,208 BatchNorm2d-314 [-1, 128, 7, 7] 256 ReLU-315 [-1, 128, 7, 7] 0 Conv2d-316 [-1, 32, 7, 7] 36,864 BatchNorm2d-317 [-1, 768, 7, 7] 1,536 ReLU-318 [-1, 768, 7, 7] 0 Conv2d-319 [-1, 128, 7, 7] 98,304 BatchNorm2d-320 [-1, 128, 7, 7] 256 ReLU-321 [-1, 128, 7, 7] 0 Conv2d-322 [-1, 32, 7, 7] 36,864 BatchNorm2d-323 [-1, 800, 7, 7] 1,600 ReLU-324 [-1, 800, 7, 7] 0 Conv2d-325 [-1, 128, 7, 7] 102,400 BatchNorm2d-326 [-1, 128, 7, 7] 256 ReLU-327 [-1, 128, 7, 7] 0 Conv2d-328 [-1, 32, 7, 7] 36,864 BatchNorm2d-329 [-1, 832, 7, 7] 1,664 ReLU-330 [-1, 832, 7, 7] 0 Conv2d-331 [-1, 128, 7, 7] 106,496 BatchNorm2d-332 [-1, 128, 7, 7] 256 ReLU-333 [-1, 128, 7, 7] 0 Conv2d-334 [-1, 32, 7, 7] 36,864 BatchNorm2d-335 [-1, 864, 7, 7] 1,728 ReLU-336 [-1, 864, 7, 7] 0 Conv2d-337 [-1, 128, 7, 7] 110,592 BatchNorm2d-338 [-1, 128, 7, 7] 256 ReLU-339 [-1, 128, 7, 7] 0 Conv2d-340 [-1, 32, 7, 7] 36,864 BatchNorm2d-341 [-1, 896, 7, 7] 1,792 ReLU-342 [-1, 896, 7, 7] 0 Conv2d-343 [-1, 128, 7, 7] 114,688 BatchNorm2d-344 [-1, 128, 7, 7] 256 ReLU-345 [-1, 128, 7, 7] 0 Conv2d-346 [-1, 32, 7, 7] 36,864 BatchNorm2d-347 [-1, 928, 7, 7] 1,856 ReLU-348 [-1, 928, 7, 7] 0 Conv2d-349 [-1, 128, 7, 7] 118,784 BatchNorm2d-350 [-1, 128, 7, 7] 256 ReLU-351 [-1, 128, 7, 7] 0 Conv2d-352 [-1, 32, 7, 7] 36,864 BatchNorm2d-353 [-1, 960, 7, 7] 1,920 ReLU-354 [-1, 960, 7, 
7] 0 Conv2d-355 [-1, 128, 7, 7] 122,880 BatchNorm2d-356 [-1, 128, 7, 7] 256 ReLU-357 [-1, 128, 7, 7] 0 Conv2d-358 [-1, 32, 7, 7] 36,864 BatchNorm2d-359 [-1, 992, 7, 7] 1,984 ReLU-360 [-1, 992, 7, 7] 0 Conv2d-361 [-1, 128, 7, 7] 126,976 BatchNorm2d-362 [-1, 128, 7, 7] 256 ReLU-363 [-1, 128, 7, 7] 0 Conv2d-364 [-1, 32, 7, 7] 36,864 BatchNorm2d-365 [-1, 1024, 7, 7] 2,048 ReLU-366 [-1, 1024, 7, 7] 0 Linear-367 [-1, 2] 2,050 ================================================================ Total params: 6,955,906 Trainable params: 6,955,906 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.57 Forward/backward pass size (MB): 294.57 Params size (MB): 26.53 Estimated Total Size (MB): 321.68 ----------------------------------------------------------------
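The total of 6,955,906 parameters in the summary can be checked by hand. The sketch below is plain arithmetic, not part of the original notebook. It assumes the configuration that can be read off the printed layers: a 64-channel stem, growth rate 32, a 128-channel bottleneck, dense blocks of 6/12/24/16 layers, transitions that halve the channel count, and a 2-class linear classifier. All function and constant names are made up for illustration.

```python
# Hand count of the parameters reported by torchsummary.
# The configuration below is an assumption read off the printed layers.
GROWTH = 32              # channels added by every dense layer
BOTTLENECK = 128         # output channels of the 1x1 conv in a dense layer
BLOCKS = (6, 12, 24, 16) # dense layers per block
NUM_CLASSES = 2


def bn_params(channels):
    """BatchNorm2d with affine=True: one weight and one bias per channel."""
    return 2 * channels


def conv_params(c_in, c_out, k):
    """Conv2d with bias=False and a k x k kernel."""
    return c_in * c_out * k * k


def dense_layer_params(c_in):
    """BN -> 1x1 conv to BOTTLENECK -> BN -> 3x3 conv to GROWTH."""
    return (bn_params(c_in) + conv_params(c_in, BOTTLENECK, 1)
            + bn_params(BOTTLENECK) + conv_params(BOTTLENECK, GROWTH, 3))


def count_params():
    channels = 64
    total = conv_params(3, channels, 7) + bn_params(channels)  # stem
    for i, num_layers in enumerate(BLOCKS):
        for _ in range(num_layers):
            total += dense_layer_params(channels)
            channels += GROWTH  # the layer output is concatenated to its input
        if i < len(BLOCKS) - 1:
            # transition: BN, then a 1x1 conv that halves the channel count
            total += bn_params(channels) + conv_params(channels, channels // 2, 1)
            channels //= 2
    total += bn_params(channels)                    # final BatchNorm (norm5)
    total += channels * NUM_CLASSES + NUM_CLASSES   # linear classifier
    return total


total = count_params()
size_mb = total * 4 / 1024 ** 2  # float32: 4 bytes per parameter
print(total, round(size_mb, 2))  # prints 6955906 26.53
```

Both numbers agree with the "Total params" and "Params size (MB)" lines of the summary, which suggests the assumed configuration matches the network that was printed.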
IV. Training the model
1. Writing the training function
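def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)   # number of samples in the training set
    num_batches = len(dataloader)    # number of batches (size / batch_size, rounded up)
    train_acc, train_loss = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Forward pass and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backward pass and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the batch loss and the number of correct predictions
        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_loss /= num_batches   # mean loss per batch
    train_acc /= size           # fraction of correctly classified samples
    return train_acc, train_loss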
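2. Writing the test function
The test function is almost the same as the training function. Because it never runs gradient descent to update the network weights, it does not take an optimizer.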
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)   # number of samples in the test set
    num_batches = len(dataloader)    # number of batches (size / batch_size, rounded up)
    test_loss, test_acc = 0, 0
    # No training happens here, so disable gradient tracking to save memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)
            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
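Both functions divide the accumulated loss by the number of batches and the number of correct predictions by the number of samples. The sketch below shows that bookkeeping in plain Python. The dataset size, split ratio and batch size come from this post; it assumes the data loaders keep the last, smaller batch (the `drop_last=False` default), which is what the "rounded up" comment implies. The helper `epoch_metrics` and the three-batch numbers are hypothetical.

```python
import math

# Sizes taken from the post: 13403 images, 80/20 split, batch_size = 8.
total_images = 13403
train_size = int(0.8 * total_images)
test_size = total_images - train_size
batch_size = 8

# Assumption: the last partial batch is kept, so the batch count is a ceiling.
train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)
print(train_size, test_size, train_batches, test_batches)  # prints 10722 2681 1341 336


def epoch_metrics(batch_losses, batch_correct, size):
    """Mirror train()/test(): mean loss per batch, accuracy per sample."""
    loss = sum(batch_losses) / len(batch_losses)
    acc = sum(batch_correct) / size
    return acc, loss


# Hypothetical values for three batches (8 + 8 + 4 samples), not from the real run.
acc, loss = epoch_metrics([0.5, 0.3, 0.4], [6, 7, 3], size=20)
print(acc, loss)
```

Note that the loss is a mean of batch means, so a short final batch weighs as much as a full one. With 1,341 training batches the effect is negligible here.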
3. Running the training
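import copy
loss_fn = nn.CrossEntropyLoss()
learn_rate = 1e-4
# SGD or Adam: pick one of the two optimizers
# opt = torch.optim.SGD(model.parameters(),lr=learn_rate)
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
scheduler=torch.optim.lr_scheduler.StepLR(opt,step_size=1,gamma=0.9) # learning-rate scheduler: multiply lr by 0.9 after every epoch
epochs = 100 # upper bound on the number of epochs; early stopping may end training sooner
patience=10 # early-stopping patience: stop after 10 consecutive epochs without a test-accuracy improvement
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc = 0 # best test accuracy so far, used to select the best model
no_improve_epoch=0 # counts consecutive epochs without an accuracy improvement
epoch=0 # index of the last epoch that ran; initialized here for the plots later, since training may stop before epochs = 100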
# Start training
for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)
    model.eval()
    epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
        no_improve_epoch=0 # reset the counter
        # Save a checkpoint of the best model
        PATH='./J3-1_best_model(j3-1).pth'
        torch.save({
            'epoch':epoch,
            'model_state_dict':best_model.state_dict(),
            'optimizer_state_dict':opt.state_dict(),
            'loss':epoch_test_loss,
        },PATH)
    else:
        no_improve_epoch += 1
    if no_improve_epoch >= patience:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break # early stop; the metrics of this last epoch are not recorded
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    scheduler.step() # update the learning rate
    lr = opt.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,epoch_test_acc*100, epoch_test_loss, lr))
# Save the weights of the best model to a file
PATH='./j3-1_best_model_.pth' # file name for the saved weights
torch.save(best_model.state_dict(),PATH)
print('Done')
print(epoch)
print('no_improve_epoch:',no_improve_epoch)
Output:
Epoch: 1, Train_acc:81.5%, Train_loss:0.418, Test_acc:82.6%, Test_loss:0.407, Lr:9.00E-05
Epoch: 2, Train_acc:85.2%, Train_loss:0.353, Test_acc:85.5%, Test_loss:0.362, Lr:8.10E-05
Epoch: 3, Train_acc:87.2%, Train_loss:0.315, Test_acc:85.3%, Test_loss:0.335, Lr:7.29E-05
Epoch: 4, Train_acc:88.1%, Train_loss:0.295, Test_acc:89.7%, Test_loss:0.252, Lr:6.56E-05
Epoch: 5, Train_acc:89.2%, Train_loss:0.265, Test_acc:88.9%, Test_loss:0.253, Lr:5.90E-05
Epoch: 6, Train_acc:89.1%, Train_loss:0.261, Test_acc:90.5%, Test_loss:0.248, Lr:5.31E-05
Epoch: 7, Train_acc:90.3%, Train_loss:0.242, Test_acc:89.1%, Test_loss:0.273, Lr:4.78E-05
Epoch: 8, Train_acc:90.8%, Train_loss:0.228, Test_acc:91.3%, Test_loss:0.238, Lr:4.30E-05
Epoch: 9, Train_acc:91.2%, Train_loss:0.224, Test_acc:91.2%, Test_loss:0.242, Lr:3.87E-05
Epoch:10, Train_acc:91.6%, Train_loss:0.214, Test_acc:91.2%, Test_loss:0.234, Lr:3.49E-05
Epoch:11, Train_acc:91.7%, Train_loss:0.207, Test_acc:91.5%, Test_loss:0.226, Lr:3.14E-05
Epoch:12, Train_acc:92.4%, Train_loss:0.198, Test_acc:91.5%, Test_loss:0.217, Lr:2.82E-05
Epoch:13, Train_acc:92.4%, Train_loss:0.189, Test_acc:89.6%, Test_loss:0.259, Lr:2.54E-05
Epoch:14, Train_acc:92.7%, Train_loss:0.180, Test_acc:91.5%, Test_loss:0.233, Lr:2.29E-05
Epoch:15, Train_acc:93.3%, Train_loss:0.171, Test_acc:91.0%, Test_loss:0.242, Lr:2.06E-05
Epoch:16, Train_acc:93.7%, Train_loss:0.163, Test_acc:91.6%, Test_loss:0.220, Lr:1.85E-05
Epoch:17, Train_acc:93.7%, Train_loss:0.158, Test_acc:92.1%, Test_loss:0.209, Lr:1.67E-05
Epoch:18, Train_acc:94.1%, Train_loss:0.148, Test_acc:92.3%, Test_loss:0.216, Lr:1.50E-05
Epoch:19, Train_acc:94.4%, Train_loss:0.145, Test_acc:91.9%, Test_loss:0.226, Lr:1.35E-05
Epoch:20, Train_acc:94.8%, Train_loss:0.135, Test_acc:90.7%, Test_loss:0.257, Lr:1.22E-05
Epoch:21, Train_acc:94.9%, Train_loss:0.129, Test_acc:90.9%, Test_loss:0.232, Lr:1.09E-05
Epoch:22, Train_acc:95.5%, Train_loss:0.121, Test_acc:91.6%, Test_loss:0.234, Lr:9.85E-06
Epoch:23, Train_acc:95.2%, Train_loss:0.123, Test_acc:92.4%, Test_loss:0.204, Lr:8.86E-06
Epoch:24, Train_acc:95.7%, Train_loss:0.114, Test_acc:92.2%, Test_loss:0.203, Lr:7.98E-06
Epoch:25, Train_acc:95.8%, Train_loss:0.109, Test_acc:91.6%, Test_loss:0.218, Lr:7.18E-06
Epoch:26, Train_acc:95.9%, Train_loss:0.108, Test_acc:90.9%, Test_loss:0.271, Lr:6.46E-06
Epoch:27, Train_acc:95.9%, Train_loss:0.109, Test_acc:91.1%, Test_loss:0.251, Lr:5.81E-06
Epoch:28, Train_acc:96.0%, Train_loss:0.103, Test_acc:91.5%, Test_loss:0.240, Lr:5.23E-06
Epoch:29, Train_acc:96.5%, Train_loss:0.097, Test_acc:92.7%, Test_loss:0.200, Lr:4.71E-06
Epoch:30, Train_acc:96.5%, Train_loss:0.093, Test_acc:91.6%, Test_loss:0.242, Lr:4.24E-06
Epoch:31, Train_acc:96.5%, Train_loss:0.092, Test_acc:92.1%, Test_loss:0.228, Lr:3.82E-06
Epoch:32, Train_acc:96.9%, Train_loss:0.086, Test_acc:91.4%, Test_loss:0.234, Lr:3.43E-06
Epoch:33, Train_acc:97.1%, Train_loss:0.081, Test_acc:92.1%, Test_loss:0.225, Lr:3.09E-06
Epoch:34, Train_acc:97.2%, Train_loss:0.078, Test_acc:92.1%, Test_loss:0.226, Lr:2.78E-06
Epoch:35, Train_acc:97.0%, Train_loss:0.082, Test_acc:91.8%, Test_loss:0.229, Lr:2.50E-06
Epoch:36, Train_acc:97.4%, Train_loss:0.076, Test_acc:91.1%, Test_loss:0.251, Lr:2.25E-06
Epoch:37, Train_acc:97.3%, Train_loss:0.075, Test_acc:91.4%, Test_loss:0.238, Lr:2.03E-06
Epoch:38, Train_acc:97.1%, Train_loss:0.078, Test_acc:90.4%, Test_loss:0.276, Lr:1.82E-06
Early stopping triggered at epoch 39
Done
38
no_improve_epoch: 10
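Two things in this log can be reproduced with a few lines of plain Python: the `Lr` column, which follows from `StepLR(step_size=1, gamma=0.9)` starting at 1e-4, and the point at which the patience counter runs out. The sketch below is only a replay of the logged numbers, not the training code; `step_lr` and `replay_patience` are hypothetical helper names. Because the logged accuracies are rounded to one decimal, ties in the log may have been real improvements during the run, but the best value (92.7% at epoch 29) is unambiguous.

```python
def step_lr(initial_lr, gamma, epochs_done):
    """Learning rate after `epochs_done` scheduler steps with step_size=1."""
    return initial_lr * gamma ** epochs_done


# Test accuracies (%) copied from the log, epochs 1..38.
test_acc_log = [82.6, 85.5, 85.3, 89.7, 88.9, 90.5, 89.1, 91.3, 91.2, 91.2,
                91.5, 91.5, 89.6, 91.5, 91.0, 91.6, 92.1, 92.3, 91.9, 90.7,
                90.9, 91.6, 92.4, 92.2, 91.6, 90.9, 91.1, 91.5, 92.7, 91.6,
                92.1, 91.4, 92.1, 92.1, 91.8, 91.1, 91.4, 90.4]


def replay_patience(acc_history):
    """Return (best_epoch, epochs_without_improvement) with the loop's rule."""
    best_acc, best_epoch, no_improve = 0.0, 0, 0
    for epoch, acc in enumerate(acc_history, start=1):
        if acc > best_acc:  # strict improvement resets the counter
            best_acc, best_epoch, no_improve = acc, epoch, 0
        else:
            no_improve += 1
    return best_epoch, no_improve


print(f"{step_lr(1e-4, 0.9, 1):.2E}")    # prints 9.00E-05, the Lr logged for epoch 1
print(f"{step_lr(1e-4, 0.9, 38):.2E}")   # prints 1.82E-06, the Lr logged for epoch 38
best_epoch, no_improve = replay_patience(test_acc_log)
print(best_epoch, no_improve)            # prints 29 9
```

After epoch 38 the counter stands at 9, so one more epoch without improvement reaches `patience = 10`, which is why training stops at epoch 39. The replay also shows that by then the learning rate had decayed to roughly 2% of its starting value.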
V. Visualizing the results
1. Loss and accuracy curves
import matplotlib.pyplot as plt
# Hide warnings
import warnings
warnings.filterwarnings("ignore") # ignore warning messages
plt.rcParams['font.sans-serif'] = ['SimHei'] # needed to render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False # needed to render the minus sign correctly
plt.rcParams['figure.dpi'] = 100 # figure resolution
# Use the number of recorded epochs; range(epoch) only matches when training stops early
epochs_range = range(len(train_acc))
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()
Output: (figure not reproduced here; accuracy curves in the left panel, loss curves in the right panel)
2. Prediction
from PIL import Image
classes = list(total_data.class_to_idx)
def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img) # show the image being classified
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0) # add the batch dimension
    model.eval()
    with torch.no_grad(): # inference only, so no gradients are needed
        output = model(img)
    _,pred = torch.max(output,1)
    pred_class = classes[pred.item()]
    print(f'Predicted class: {pred_class}')
import os
from pathlib import Path
import random
# Named random_image_path so that the image_path variable below does not shadow the function
def random_image_path(data_dir):
    data_dir=Path(data_dir)
    images=[] # one list of files per class
    for class_name in os.listdir(data_dir): # one sub-folder per class label ('0' and '1')
        class_dir=data_dir.joinpath(class_name) # build the path to the class folder
        images.append(list(class_dir.iterdir())) # list the files in that folder
    class_files=random.choice(images) # pick one class at random
    return random.choice(class_files) # pick one image at random from that class
data_dir='J3-1/'
image_path=random_image_path(data_dir)
image_path
Output:
WindowsPath('J3-1/1/9075_idx5_x1251_y251_class1.png')
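The file name itself carries the ground-truth label (`class1`), which makes it easy to compare a prediction with the true class. The sketch below parses such a name with a regular expression. The naming scheme `<id>_idx5_x<X>_y<Y>_class<label>.png` is inferred from this single path; only the trailing class digit is confirmed by the folder the file sits in, and the meaning given to the other fields (a case identifier and patch coordinates) is an assumption. `parse_patch_name` is a hypothetical helper.

```python
import re
from pathlib import PureWindowsPath

# Assumed naming scheme: <id>_idx5_x<X>_y<Y>_class<label>.png
# Field meanings other than the class label are guesses.
PATTERN = re.compile(
    r"^(?P<case>\d+)_idx\d+_x(?P<x>\d+)_y(?P<y>\d+)_class(?P<label>[01])\.png$"
)


def parse_patch_name(path):
    """Split a patch file name into integer fields; raise ValueError on a mismatch."""
    name = PureWindowsPath(path).name  # accepts both '/' and '\\' separators
    match = PATTERN.match(name)
    if match is None:
        raise ValueError(f"unexpected file name: {name}")
    return {key: int(value) for key, value in match.groupdict().items()}


info = parse_patch_name("J3-1/1/9075_idx5_x1251_y251_class1.png")
print(info)  # prints {'case': 9075, 'x': 1251, 'y': 251, 'label': 1}
```

Comparing `info['label']` with the predicted class gives a quick correctness check for any randomly drawn image.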
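# Predict one image picked at random from the dataset folder
# Note: train_transforms includes RandomHorizontalFlip; test_transforms would give a deterministic input.
# best_model (the weights with the highest test accuracy) can be passed instead of model.
predict_one_image(image_path=image_path,
                  model=model,
                  transform=train_transforms,
                  classes=classes)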
Output:
Predicted class: 1
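`predict_one_image` reports only the arg-max class. When a confidence value is wanted as well, the two raw scores can be turned into probabilities with a softmax (in PyTorch, `torch.softmax(output, dim=1)`). The sketch below shows the computation in plain Python; the logits are hypothetical and do not come from the trained model.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    peak = max(logits)                           # subtract the max to avoid overflow
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for classes '0' and '1'; not taken from the real model.
logits = [-1.2, 2.3]
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # same result as argmax on the logits
print(pred, [round(p, 3) for p in probs])
```

The predicted index is identical with or without the softmax, since the function is monotonic; it only changes how the scores are reported.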
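VI. Summary
Compared with my earlier image-recognition exercises, this run took far longer: close to three days of training. The most likely reason is the size of the dataset, which holds 13,403 images (10,722 of them in the training split), and everything ran on the CPU.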