一. ShuffleNet V2 神经网络介绍
ShuffleNet V2 是对 ShuffleNet V1 的改进版本,旨在解决其前代的一些问题,例如 1×1 分组卷积带来的高内存访问成本,以及仅用 FLOPs 衡量效率与实际推理速度脱节等。以下是 ShuffleNet V2 的一些关键特点:
ShuffleNet V2 特点
- 通道分割(Channel Split):
- ShuffleNet V2 使用了一种称为“channel split”的技术,该技术将输入通道分成两半,一半做恒等映射,另一半经过卷积处理,最后将两部分结果拼接,以更低的代价获得更好的性能。
- 取消 1×1 分组卷积:
- ShuffleNet V2 放弃了 V1 中的 1×1 分组卷积,以降低内存访问成本。
- 等宽卷积:
- 单元内卷积层保持输入与输出通道数相等,使内存访问成本最小化。
- Concat 代替 Add:
- 两个分支的输出通过拼接而非逐元素相加合并,并在其后进行通道混洗,减少元素级操作的同时保证分支间的信息交流。
ShuffleNet V2 结构
ShuffleNet V2 仍然基于通道混洗的思想,但引入了一些新的设计决策来改善性能和效率。例如,它用通道分割取代了 V1 中的 1×1 分组卷积,并让单元内卷积层的输入与输出通道数保持相等。此外,两个分支的输出通过拼接(而非相加)合并,再做通道混洗以交换分支间的信息。这些改动在降低内存访问成本的同时保持了准确率。
ShuffleNet V2 单元
ShuffleNet V2 单元包含以下组件:
- 通道分割:
- 输入通道被分割成两部分,一部分做恒等映射,另一部分进入卷积分支。
- 卷积分支:
- 依次执行 1×1 卷积、3×3 深度卷积(DW 卷积)和 1×1 卷积,且输入与输出通道数保持相等。
- 拼接(Concat):
- 两个分支的输出在通道维拼接,使单元整体的通道数保持不变。
- 通道混洗:
- 拼接后进行通道混洗,打破两个分支之间的信息隔离。
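上面“通道混洗”的重排效果可以用一个纯 Python 小例子直观验证(不依赖 PyTorch,仅示意“分组 → 转置 → 展平”的过程,通道数与分组数均为示意取值):

```python
def shuffle_channels(channels, groups):
    # 先把通道列表按组切成 (groups, n) 的二维结构
    n = len(channels) // groups
    grouped = [channels[i * n:(i + 1) * n] for i in range(groups)]
    # “转置”后展平: 各组的通道交错排列, 即 channel shuffle
    return [grouped[g][i] for i in range(n) for g in range(groups)]

# 6 个通道、2 组: 前一半与后一半的通道交错
print(shuffle_channels([0, 1, 2, 3, 4, 5], groups=2))
# -> [0, 3, 1, 4, 2, 5]
```

可以看到混洗后每个位置都混合了来自不同组的通道,这正是打破信息隔离的关键。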
总结
ShuffleNet V2 在保留了原版 ShuffleNet 高效性的同时,通过引入通道分割、取消 1×1 分组卷积、以拼接代替相加等方法,降低了内存访问成本,提高了模型的实际推理速度与准确率。这种改进使得 ShuffleNet V2 更适用于资源受限的设备,如手机和嵌入式系统。
二. ShuffleNet V2 神经网络细节
ShuffleNet V2的设计考虑了多个因素以提高计算效率和性能。以下是更详细的介绍:
1. ShuffleNet V2 设计准则
G1:内存访问成本最小化
ShuffleNet V2 试图最小化内存访问成本(MAC)。在深度可分离卷积中,1×1 卷积的 FLOPs 为

$$B = hwc_{1}c_{2}$$

内存访问成本为:

$$MAC = hw(c_{1}+c_{2}) + c_{1}c_{2}$$

根据均值不等式,MAC 至少为:

$$MAC \geq 2\sqrt{hwB} + \frac{B}{hw}$$

只有当 $c_{1} = c_{2}$ 时,MAC 才能取到最小值。
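G1 可以用一段纯 Python 小实验验证(h=w=28 及各通道组合均为示意取值):在 FLOPs 固定的前提下,c1=c2 时 MAC 最小,且恰好等于下界 2√(hwB)+B/hw:

```python
import math

def mac_1x1(h, w, c1, c2):
    # 1x1 卷积的 MAC: 读输入 hw*c1 + 写输出 hw*c2 + 读权重 c1*c2
    return h * w * (c1 + c2) + c1 * c2

h = w = 28
B = h * w * 64 * 64  # 固定 FLOPs
for c1, c2 in [(16, 256), (32, 128), (64, 64)]:
    assert h * w * c1 * c2 == B  # 三种组合的 FLOPs 完全相同
    print(c1, c2, mac_1x1(h, w, c1, c2))
# MAC 依次为 217344, 129536, 104448, 通道数越均衡 MAC 越小
lower_bound = 2 * math.sqrt(h * w * B) + B / (h * w)
print(int(lower_bound))  # 104448, 与 c1=c2 时的 MAC 相等
```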
G2:避免过多使用分组卷积
对于分组卷积,FLOPs 为

$$B = hwc_{1}c_{2}/g$$

MAC 为:

$$MAC = hw(c_{1}+c_{2}) + c_{1}c_{2}/g$$

如果固定输入 $c_{1} \times h \times w$ 和 FLOPs $B$,则 MAC 为:

$$MAC = hwc_{1} + Bg/c_{1} + B/hw$$

随着分组数 $g$ 的增加,MAC 也会增加。
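同样可以用纯 Python 验证 G2:固定输入尺寸与 FLOPs 后,MAC = hwc1 + Bg/c1 + B/hw 随 g 单调增加(h=w=28、c1=64 为示意取值):

```python
h = w = 28
c1 = 64
B = h * w * c1 * c1  # 固定 FLOPs
macs = []
for g in [1, 2, 4, 8]:
    # 固定输入与 B 时的 MAC 表达式
    mac = h * w * c1 + B * g // c1 + B // (h * w)
    macs.append(mac)
    print("g={}: MAC={}".format(g, mac))
# MAC 依次为 104448, 154624, 254976, 455680, 随 g 增加而增加
```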
G3:避免网络碎片化
ShuffleNet V2 避免使用多路结构:网络碎片化虽然可能有利于准确率,但会降低 GPU 等设备上的并行度,并带来额外的内核启动与同步开销,从而减慢速度。
G4:重视元素级操作
ShuffleNet V2 注意到 ReLU、Tensor Add 和 Bias Add 等元素级操作虽然 FLOPs 少,但 MAC 大。实验表明,移除 ResNet bottleneck 中的 ReLU 和 shortcut(短接)后,速度可提升约 20%。
因此,基于上述分析,作者得到了4条指导准则:
- 使用“平衡”卷积层,即输入与输出通道相同;
- 谨慎使用分组卷积并注意分组数;
- 减少碎片化的操作;
- 减少元素级的操作。
2. ShuffleNet V2 Block
下图中的 (a) 与 (b) 是 ShuffleNet V1 的结构,(c) 与 (d) 是 ShuffleNet V2 的结构。仔细观察可以发现,ShuffleNet V1 多处违背了上述 4 条设计准则:
- 使用了 bottleneck layer,使得输入与输出通道数不同,违背了 G1 原则;
- 大量使用 1×1 分组卷积,违背了 G2 原则;
- 使用了过多的 group,加剧网络碎片化,违背了 G3 原则;
- shortcut 中存在大量的元素级 Add 运算,违背了 G4 原则。
ShuffleNet V2 是对ShuffleNet V1 的改进版本,旨在解决前者的几个问题。以下是针对所提到的几点的详细分析:
- 通道拆分(Channel Split):
- ShuffleNet V2 引入了一个新的操作——通道拆分,如图 (c) 和 (d) 所示。它将输入通道分成两个分支:一个分支做恒等映射,直接传递到输出,减少了分支数量,符合 G3 原则,即避免网络碎片化;另一个分支包含多层卷积,且输入与输出通道数相等,符合 G1 原则,即最小化内存访问成本。
- 取消分组卷积:
- ShuffleNet V2 放弃了ShuffleNet V1 中使用的分组卷积,特别是 1x1 分组卷积,以减少内存访问成本,符合 G2 原则。
- 使用 Concat 代替 Add:
- ShuffleNet V2 使用 concatenate 操作而不是逐元素的 Tensor Add 来合并两个分支的输出:拼接只做内存搬移,不涉及逐元素算术运算,符合 G4 原则;拼接后紧跟通道混洗,保证两个分支之间的信息交流。
通过这些改进,ShuffleNet V2 在保持计算效率的同时提高了准确率,并减少了内存访问成本。此外,它还避免了网络碎片化,降低了计算复杂度,从而提高了整体性能。
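下面用一个极简的 PyTorch 片段对比 Add 与 Concat(随机张量,通道数 58 为示意取值):Add 需要逐元素运算且通道数不变,Concat 不做任何算术运算,通道数直接翻倍:

```python
import torch

x1 = torch.randn(1, 58, 28, 28)  # 恒等映射分支的输出(58 为示意通道数)
x2 = torch.randn(1, 58, 28, 28)  # 卷积分支的输出
# Add: 逐元素相加, 需要 58*28*28 次元素级运算, 通道数不变
added = x1 + x2
# Concat: 只在通道维拼接, 无逐元素算术运算, 通道数翻倍
merged = torch.cat((x1, x2), dim=1)
print(added.shape)   # torch.Size([1, 58, 28, 28])
print(merged.shape)  # torch.Size([1, 116, 28, 28])
```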
三. ShuffleNet V2 神经网络结构
ShuffleNet V2 是一种专门为移动设备优化的高效卷积神经网络,它利用通道拆分和通道级联操作来提高计算效率。下面是ShuffleNet V2 结构的详细概述:
层次结构
ShuffleNet V2 由一系列阶段(Stage)组成,每个阶段都包含多个基本单元(Basic Unit)。表中列出了各个阶段及其对应的输出大小、卷积核大小(KSize)、步长(Stride)和重复次数(Repeat)。此外,还给出了不同缩放因子(Scaling Factor)下的输出通道数。
阶段描述
- Conv1:
- 这是第一个卷积层,采用 3×3 的卷积核,步长为 2,输入为 224×224 的图像,输出大小为 112×112。
- 输出通道数为 24,对于所有缩放因子都是相同的。
- MaxPool:
- 接着是一个最大池化层,同样采用 3×3 的核,步长为 2,输出大小为 56×56。
- Stage2:
- Stage2 由 4 个基本单元组成,每个单元由通道分割(或双分支下采样)、深度卷积和通道混洗构成。
- 输出大小为 28×28。
- 首先采用 1 个 stride=2 的下采样单元,紧接着重复 3 次 stride=1 的基本单元。
- 输出通道数随缩放因子的增大而增加。
- Stage3:
- Stage3 由 8 个基本单元组成,结构与 Stage2 相同。
- 输出大小为 14×14。
- 首先采用 1 个 stride=2 的下采样单元,紧接着重复 7 次 stride=1 的基本单元。
- 输出通道数随缩放因子的增大而增加。
- Stage4:
- Stage4 由 4 个基本单元组成,结构与 Stage2 相同。
- 输出大小为 7×7。
- 首先采用 1 个 stride=2 的下采样单元,紧接着重复 3 次 stride=1 的基本单元。
- 输出通道数随缩放因子的增大而增加。
- Conv5:
- Conv5 是一个 1×1 的卷积层,用于在全局池化之前提升特征维度。
- 在 0.5×、1.0×、1.5× 缩放因子下输出通道数均为 1024,在 2.0× 下为 2048。
- GlobalPool and FC:
- 全局平均池化层将 7×7 的特征图压缩为一个向量,然后传递给全连接层进行分类预测。
- 表中同时列出了各缩放因子对应的 FLOPs 和参数量。
ShuffleNet V2 的设计目标是在保持准确性的同时尽可能地减少计算量。它通过通道拆分、通道混洗和适当的重复次数来实现这一点。此外,它还优化了内存访问成本,避免了不必要的分组卷积,并尽量减少了网络碎片化。
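按照上表的步长设置,各阶段的特征图尺寸可以用标准的卷积输出尺寸公式 out = ⌊(in + 2p − k)/s⌋ + 1 复算一遍(k=3、s=2、p=1 对应各层首个下采样操作):

```python
def conv_out(size, k=3, s=2, p=1):
    # 标准卷积/池化输出尺寸公式: out = (in + 2p - k) // s + 1
    return (size + 2 * p - k) // s + 1

sizes = []
size = 224  # 输入图像尺寸
for layer in ["Conv1", "MaxPool", "Stage2", "Stage3", "Stage4"]:
    size = conv_out(size)  # 这些层的首个操作步长均为 2
    sizes.append(size)
    print("{}: {}x{}".format(layer, size, size))
# 依次为 112, 56, 28, 14, 7, 与上表一致
```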
四. ShuffleNet V2 代码实现
开发环境配置说明:本项目使用 Python 3.6.13 和 PyTorch 1.10.2 构建,适用于 CPU 环境。
- model.py:定义网络模型
- train.py:加载数据集并训练,计算 loss 和 accuracy,保存训练好的网络参数
- predict.py:用自己的数据集进行分类测试
- utils.py:依赖脚本
- my_dataset.py:依赖脚本
- model.py
from typing import List, Callable
import torch
from torch import Tensor
import torch.nn as nn
# 定义channel_shuffle
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
# 获取输入x的[B, C, H, W]
batch_size, num_channels, height, width = x.size()
# 获取每个组的channel
channel_per_group = num_channels // groups
# reshape
    # [B, C, H, W] -> [B, G, C//G, H, W]
    x = x.view(batch_size, groups, channel_per_group, height, width)
    # 交换维度1和维度2 -> [B, C//G, G, H, W]
    x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
class InvertedResidual(nn.Module):
# input_c:输入特征矩阵通道 output_c:输出特征矩阵通道 stride:DW卷积步幅
def __init__(self, input_c: int, output_c: int, stride: int):
super(InvertedResidual, self).__init__()
# 判断stride是否只取1和2
if stride not in [1, 2]:
raise ValueError("illegal stride value")
self.stride = stride
# 判断output_c是否为2的整数倍(结构左右分支的通道数都是相同的)
assert output_c % 2 == 0
branch_features = output_c // 2
        # 当 stride 为 1 时,input_c 应该是 branch_features 的两倍
        # Python 中 "<<" 是位运算,可理解为快速计算 x2 的方法
assert (self.stride != 1) or (input_c == branch_features << 1)
if self.stride == 2:
self.branch1 = nn.Sequential(
self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
nn.BatchNorm2d(input_c),
nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True)
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True)
)
@staticmethod
def depthwise_conv(input_c: int,
output_c: int,
kernel_s: int,
stride: int = 1,
padding: int = 0,
                       bias: bool = False) -> nn.Conv2d:
return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
stride=stride, padding=padding, bias=bias, groups=input_c)
def forward(self, x: Tensor) -> Tensor:
if self.stride == 1:
# 将channel均分成两份
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
class ShuffleNetV2(nn.Module):
# stage_repeats:Block重复次数
def __init__(self,
stages_repeats: List[int],
stages_out_channels: List[int],
num_classes: int = 1000,
inverted_residual: Callable[..., nn.Module] = InvertedResidual):
super(ShuffleNetV2, self).__init__()
if len(stages_repeats) != 3:
raise ValueError("expected stages_repeats as list of 3 positive ints")
if len(stages_out_channels) != 5:
raise ValueError("expected stages_out_channels as list of 5 positive ints")
self._stage_out_channels = stages_out_channels
# input RGB image
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True)
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 先声明 stage2、stage3、stage4 的类型,下面通过 setattr 动态创建
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(stage_names, stages_repeats,
self._stage_out_channels[1:]):
seq = [inverted_residual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(inverted_residual(output_channels, output_channels, 1))
# 使用setattr(self, name, nn.Sequential(*seq))来将创建好的nn.Sequential对象设置为当前类实例的一个属性,属性名为name
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True)
)
self.fc = nn.Linear(output_channels, num_classes)
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
x = x.mean([2, 3]) # global pool
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def shufflenet_v2_x0_5(num_classes=1000):
"""
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`.
weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth
:param num_classes:
:return:
"""
model = ShuffleNetV2(stages_repeats=[4, 8, 4],
stages_out_channels=[24, 48, 96, 192, 1024],
num_classes=num_classes)
return model
def shufflenet_v2_x1_0(num_classes=1000):
"""
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`.
weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
:param num_classes:
:return:
"""
model = ShuffleNetV2(stages_repeats=[4, 8, 4],
stages_out_channels=[24, 116, 232, 464, 1024],
num_classes=num_classes)
return model
def shufflenet_v2_x1_5(num_classes=1000):
"""
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`.
weight: https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth
:param num_classes:
:return:
"""
model = ShuffleNetV2(stages_repeats=[4, 8, 4],
stages_out_channels=[24, 176, 352, 704, 1024],
num_classes=num_classes)
return model
def shufflenet_v2_x2_0(num_classes=1000):
"""
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`.
weight: https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth
:param num_classes:
:return:
"""
model = ShuffleNetV2(stages_repeats=[4, 8, 4],
stages_out_channels=[24, 244, 488, 976, 2048],
num_classes=num_classes)
return model
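作为对上面 model.py 中 channel_shuffle 的快速自检,下面的自包含片段(为便于独立运行重新内联了该函数)验证了分 2 组混洗后通道确实交错排列:

```python
import torch

def channel_shuffle(x, groups):
    # 与 model.py 中实现一致: reshape -> transpose -> flatten
    batch_size, num_channels, height, width = x.size()
    channel_per_group = num_channels // groups
    x = x.view(batch_size, groups, channel_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(batch_size, -1, height, width)

# 6 个通道分别填入编号 0~5, 分 2 组混洗后应交错排列
x = torch.arange(6, dtype=torch.float32).view(1, 6, 1, 1)
y = channel_shuffle(x, groups=2)
print(y.flatten().tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
```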
- train.py
import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(args)
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
tb_writer = SummaryWriter()
if os.path.exists("./weights") is False:
os.makedirs("./weights")
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
# 实例化训练数据集
train_dataset = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])
# 实例化验证数据集
val_dataset = MyDataSet(images_path=val_images_path,
images_class=val_images_label,
transform=data_transform["val"])
batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=nw,
collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=nw,
collate_fn=val_dataset.collate_fn)
# 如果存在预训练权重则载入
model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)
if args.weights != "":
if os.path.exists(args.weights):
weights_dict = torch.load(args.weights, map_location=device)
load_weights_dict = {k: v for k, v in weights_dict.items()
if model.state_dict()[k].numel() == v.numel()}
print(model.load_state_dict(load_weights_dict, strict=False))
else:
raise FileNotFoundError("not found weights file: {}".format(args.weights))
# 是否冻结权重
if args.freeze_layers:
for name, para in model.named_parameters():
# 除最后的全连接层外,其他权重全部冻结
if "fc" not in name:
para.requires_grad_(False)
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
for epoch in range(args.epochs):
# train
mean_loss = train_one_epoch(model=model,
optimizer=optimizer,
data_loader=train_loader,
device=device,
epoch=epoch)
scheduler.step()
# validate
acc = evaluate(model=model,
data_loader=val_loader,
device=device)
print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
tags = ["loss", "accuracy", "learning_rate"]
tb_writer.add_scalar(tags[0], mean_loss, epoch)
tb_writer.add_scalar(tags[1], acc, epoch)
tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)
torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--lrf', type=float, default=0.1)
# 数据集所在根目录
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="E:/code/PyCharm_Projects/deep_learning/data_set/flower_data/flower_photos")
# shufflenetv2_x1.0 官方权重下载地址
# https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
# 不使用预训练权重 default=''
parser.add_argument('--weights', type=str, default='./shufflenetv2_x1.pth',
help='initial weights path')
# 冻结除最后全连接层的所有权重 default=True
parser.add_argument('--freeze-layers', type=bool, default=True)
parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
opt = parser.parse_args()
main(opt)
- predict.py
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import shufflenet_v2_x1_0
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# load image
img_path = "郁金香.png"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path).convert('RGB')  # PNG 可能带 alpha 通道, 统一转为 RGB
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
with open(json_path, "r") as f:
class_indict = json.load(f)
# create model
model = shufflenet_v2_x1_0(num_classes=5).to(device)
# load model weights
model_weight_path = "./weights/model-0.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
predict[i].numpy()))
plt.show()
if __name__ == '__main__':
main()
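predict.py 中“softmax 取概率、argmax 取类别”的后处理可以脱离模型单独验证(下面的 logits 为随意取的示意值):

```python
import torch

# 模拟模型对 5 个类别输出的 logits
output = torch.tensor([1.0, 3.0, 0.5, 2.0, -1.0])
predict = torch.softmax(output, dim=0)      # 转为概率分布
predict_cla = torch.argmax(predict).item()  # 概率最大的类别索引
print(predict_cla)  # 1
print(round(predict.sum().item(), 6))  # 1.0, softmax 后概率和为 1
```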
- utils.py
import os
import sys
import json
import pickle
import random
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
def read_split_data(root: str, val_rate: float = 0.2):
random.seed(0) # 保证随机结果可复现
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
# 排序,保证各平台顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
train_images_path = [] # 存储训练集的所有图片路径
train_images_label = [] # 存储训练集图片对应索引信息
val_images_path = [] # 存储验证集的所有图片路径
val_images_label = [] # 存储验证集图片对应索引信息
every_class_num = [] # 存储每个类别的样本总数
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
# 遍历每个文件夹下的文件
for cla in flower_class:
cla_path = os.path.join(root, cla)
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
# 排序,保证各平台顺序一致
images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
every_class_num.append(len(images))
# 按比例随机采样验证样本
val_path = random.sample(images, k=int(len(images) * val_rate))
for img_path in images:
if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
val_images_path.append(img_path)
val_images_label.append(image_class)
else: # 否则存入训练集
train_images_path.append(img_path)
train_images_label.append(image_class)
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must be greater than 0."
    assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
# 绘制每种类别个数柱状图
plt.bar(range(len(flower_class)), every_class_num, align='center')
# 将横坐标0,1,2,3,4替换为相应的类别名称
plt.xticks(range(len(flower_class)), flower_class)
# 在柱状图上添加数值标签
for i, v in enumerate(every_class_num):
plt.text(x=i, y=v + 5, s=str(v), ha='center')
# 设置x坐标
plt.xlabel('image class')
# 设置y坐标
plt.ylabel('number of images')
# 设置柱状图的标题
plt.title('flower class distribution')
plt.show()
return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader):
batch_size = data_loader.batch_size
plot_num = min(batch_size, 4)
json_path = './class_indices.json'
assert os.path.exists(json_path), json_path + " does not exist."
json_file = open(json_path, 'r')
class_indices = json.load(json_file)
for data in data_loader:
images, labels = data
for i in range(plot_num):
# [C, H, W] -> [H, W, C]
img = images[i].numpy().transpose(1, 2, 0)
# 反Normalize操作
img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
label = labels[i].item()
plt.subplot(1, plot_num, i+1)
plt.xlabel(class_indices[str(label)])
plt.xticks([]) # 去掉x轴的刻度
plt.yticks([]) # 去掉y轴的刻度
plt.imshow(img.astype('uint8'))
plt.show()
def write_pickle(list_info: list, file_name: str):
with open(file_name, 'wb') as f:
pickle.dump(list_info, f)
def read_pickle(file_name: str) -> list:
with open(file_name, 'rb') as f:
info_list = pickle.load(f)
return info_list
def train_one_epoch(model, optimizer, data_loader, device, epoch):
model.train()
loss_function = torch.nn.CrossEntropyLoss()
mean_loss = torch.zeros(1).to(device)
optimizer.zero_grad()
data_loader = tqdm(data_loader, file=sys.stdout)
for step, data in enumerate(data_loader):
images, labels = data
pred = model(images.to(device))
loss = loss_function(pred, labels.to(device))
loss.backward()
mean_loss = (mean_loss * step + loss.detach()) / (step + 1) # update mean losses
data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss)
sys.exit(1)
optimizer.step()
optimizer.zero_grad()
return mean_loss.item()
@torch.no_grad()
def evaluate(model, data_loader, device):
model.eval()
# 验证样本总个数
total_num = len(data_loader.dataset)
# 用于存储预测正确的样本个数
sum_num = torch.zeros(1).to(device)
data_loader = tqdm(data_loader, file=sys.stdout)
for step, data in enumerate(data_loader):
images, labels = data
pred = model(images.to(device))
pred = torch.max(pred, dim=1)[1]
sum_num += torch.eq(pred, labels.to(device)).sum()
return sum_num.item() / total_num
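evaluate 中的准确率统计逻辑同样可以单独验证(下面的 logits 与标签为示意值):

```python
import torch

# 3 个样本、2 个类别的模拟输出
pred = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 0, 0])
pred_cls = torch.max(pred, dim=1)[1]  # 每行取最大值所在的类别索引
correct = torch.eq(pred_cls, labels).sum().item()
print(correct, "/", len(labels))  # 2 / 3, 第三个样本预测错误
```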
- my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
"""自定义数据集"""
def __init__(self, images_path: list, images_class: list, transform=None):
self.images_path = images_path
self.images_class = images_class
self.transform = transform
def __len__(self):
return len(self.images_path)
def __getitem__(self, item):
img = Image.open(self.images_path[item])
# RGB为彩色图片,L为灰度图片
if img.mode != 'RGB':
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
label = self.images_class[item]
if self.transform is not None:
img = self.transform(img)
return img, label
@staticmethod
def collate_fn(batch):
# 官方实现的default_collate可以参考
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0)
labels = torch.as_tensor(labels)
return images, labels
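collate_fn 的拼装逻辑可以脱离 DataLoader 单独验证(下面用全 0/全 1 张量模拟两个样本,尺寸为示意值):

```python
import torch

# 模拟一个 batch: 每个元素是 (图像张量, 标签)
batch = [(torch.zeros(3, 224, 224), 0), (torch.ones(3, 224, 224), 1)]
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0)  # -> [N, C, H, W]
labels = torch.as_tensor(labels)     # -> [N]
print(images.shape)     # torch.Size([2, 3, 224, 224])
print(labels.tolist())  # [0, 1]
```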