一. ShuffleNet v1 神经网络介绍
ShuffleNet V1 是一种专为移动设备设计的高效卷积神经网络(CNN),旨在解决计算资源有限的问题。它由阿里巴巴达摩院的研究人员提出,并在2018年的论文《ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices》中首次被介绍。以下是ShuffleNet V1的一些关键特点和组成部分:
关键特点
-
Group Convolution (分组卷积):
- ShuffleNet V1 使用分组卷积来减少计算成本。分组卷积将输入通道分成几个组,并在每个组内独立执行卷积操作。这类似于Inception模块中的并行分支,但更加高效。
-
Channel Shuffle (通道重排):
- 为了克服分组卷积导致的信息隔离问题,ShuffleNet V1 引入了通道重排操作。这个操作使得不同组之间的信息能够混合,从而提高了模型的表达能力。
-
Depthwise Separable Convolutions (深度可分离卷积):
- 尽管ShuffleNet V1 主要依赖于分组卷积,但它也使用深度可分离卷积作为其构建块之一,进一步减少计算量。
-
Efficiency and Accuracy Trade-off (效率与准确性的权衡):
- ShuffleNet V1 通过调整分组的数量来控制模型的复杂度,从而在计算效率和分类精度之间找到最佳平衡点。
架构组成
ShuffleNet V1 的基本单元称为 Shuffle 单元(Shuffle Unit),它通常包含两个子单元:基本的 Shuffle 单元和下采样 Shuffle 单元。
-
Basic Shuffle Unit:
- 该单元用于网络中的大部分部分,其中包含两个主要步骤:
- 第一步使用逐点卷积(1x1卷积)进行通道降维。
- 第二步包括两个并行路径:
- 分组卷积(GConv):对输入进行分组卷积。
- 直接传递:将一部分输入通道直接传递到下一个层。
- 最后,通过通道重排(Channel Shuffle)将这两个路径的输出合并起来。
- 该单元用于网络中的大部分部分,其中包含两个主要步骤:
-
Downsample Shuffle Unit:
- 当需要减小空间尺寸时,会使用下采样 Shuffle 单元。这种单元通常包含一个额外的深度可分离卷积层来进行下采样操作。
应用场景
ShuffleNet V1 特别适合于移动设备和嵌入式系统,因为它在保证一定精度的前提下大大降低了计算需求。它适用于各种计算机视觉任务,包括但不限于图像分类、目标检测和语义分割。
总结
ShuffleNet V1 是一种轻量级的卷积神经网络架构,通过使用分组卷积和通道重排技术,在保持较高分类准确率的同时,极大地降低了模型的计算复杂度和内存占用。这种设计使其成为在资源受限设备上运行的理想选择。
二. ShuffleNet v1 神经网络细节
下面是关于ShuffleNet V1神经网络中三个关键概念的详细说明:分组卷积(Group Convolution)、通道混洗(Channel Shuffle)以及ShuffleNet基础单元(ShuffleNet Unit)。
1. Group Convolution (分组卷积)
定义
分组卷积是一种高效的卷积运算方式,它将输入通道分成多个组,并在每个组内独立进行卷积运算。这种做法减少了计算量和参数数量,尤其是在处理高分辨率图像时效果显著。
工作原理
- 输入分组:
- 假设输入张量具有 C C C 个通道,我们将这些通道分成 G G G 组,每组有 C G \frac{C}{G} GC 个通道。
- 独立卷积:
- 每个组内的通道只与本组的滤波器进行卷积,而不是整个输入张量的所有通道。
- 输出拼接:
- 每个组的输出结果被拼接在一起形成最终的输出张量。
优势
- 计算效率:
- 分组卷积显著减少了计算量,因为每个滤波器只需要处理输入的一部分。
- 参数减少:
- 由于每个滤波器只作用于一部分输入通道,因此模型的参数数量也随之减少。
2. Channel Shuffle (通道混洗)
定义
通道混洗是一种特殊的操作,用于在分组卷积之后混合不同组之间的信息。它的目的是打破分组卷积导致的信息隔离,提高模型的表示能力。
图(a)展示了只有分组卷积的情况;不同颜色代表不同的分组,每个分组的输入的没有掺杂其他分组的特征,这就相当于各自管各自的,导致了分组之间信息的闭塞。如果允许每个分组卷积获取不同组的特征,如图(b)所示,将 GConv1 所有分组的输出特征 Feature 都根据组数均匀分发,作为 GConv2 每个分组的输入,那么输出(Output)和输入(Input)通道就完全相关了。这种混洗操作可以通过图(c)的通道混洗高效优雅地实现。图(c)展示了既使用分组卷积又使用通道混洗时的情况,它打破了分组卷积带来的信息隔离。箭头表示通道间的交换过程,将每个组内的通道重新排列。通过通道混洗,ShuffleNet V1能够更好地混合来自不同组的信息,从而增强模型的表示能力和泛化能力。
工作原理
- 输入分组:
- 输入张量仍然按照 G G G 个组进行划分。
- 混洗操作:
- 混洗操作首先将每个组内的通道重新排列,然后将所有组的通道重新组合。
- 具体来说,假设每个组有 C G \frac{C}{G} GC 个通道,混洗操作将每个通道从第 i i i 组移动到第 i i i 个位置的新组,其中 i i i 是通道的索引除以 C G \frac{C}{G} GC 的余数。
- 输出:
- 经过混洗操作后的输出张量保留了原始张量的形状,但通道间的信息得到了更好的混合。
优势
- 信息流动:
- 通道混洗促进了不同组之间的信息流动,有助于提高模型的泛化能力。
- 计算效率:
- 通道混洗操作本身计算非常简单,几乎不增加额外的计算负担。
3. ShuffleNet Unit (ShuffleNet基础单元)
定义
ShuffleNet的基础单元(ShuffleNet Unit)是网络的基本构建块,它结合了分组卷积和通道混洗操作,以实现高效的计算。
图(a)是一个典型的带有深度可分离卷积残差结构,ShuffleNet_V1 在此基础上设计出 ShuffleNet 单元。图(b)则是 stride=1 时的 ShuffleNet 单元,使用 1x1 分组卷积代替密集的 1x1 卷积,降低原 1x1 卷积的开销,同时加入 Channel Shuffle 实现跨通道信息交流。图(c)则是 stride=2 时的 ShuffleNet 单元,因为需要对特征图进行下采样,因此在图(b)结构基础上对残差连接分支采用 stride=2 的 3x3 全局平局池化,然后将主干输出特征和分支特征进行 concat,而不再是 add,大大的降低计算量与参数大小。
优势
- 计算效率:
- ShuffleNet Unit 利用分组卷积和通道混洗操作,显著减少了计算量。
- 表示能力:
- 通过通道混洗增强了信息流动,提高了模型的表示能力。
- 灵活性:
- 可以通过调整分组数 ( G ) 来控制模型的复杂度。
总结
ShuffleNet V1 通过使用分组卷积来减少计算量,通过通道混洗来促进不同组之间的信息流动,以及通过ShuffleNet基础单元来构建整个网络结构。这些技术共同作用,使得ShuffleNet V1 能够在资源受限的设备上高效运行,同时保持较高的分类准确率。
三. ShuffleNet v1 神经网络结构
ShuffleNet V1 是一种专门为移动设备设计的高效卷积神经网络,它采用了分组卷积和通道混洗操作来降低计算复杂度。下面是对ShuffleNet V1结构的详细描述:
层次结构
ShuffleNet V1 由一系列阶段(Stage)组成,每个阶段都包含多个基本单元(Unit)。表中列出了各个阶段及其对应的输出大小、卷积核大小(KSize)、步长(Stride)和重复次数(Repeat)。此外,还给出了不同分组数(g groups)下的输出通道数。
阶段描述
-
Conv1:
- 这是第一个卷积层,用于初始化输入图像。它采用 3 × 3 3 \times 3 3×3 的卷积核,步长为 2 2 2 ,输出大小为 112 × 112 112 \times 112 112×112 。
- 输出通道数为 24 24 24 ,对于所有分组数都是相同的。
-
MaxPool:
- 接着是一个最大池化层,同样采用 3 × 3 3 \times 3 3×3 的核,步长为 2 2 2 ,输出大小为 56 × 56 56 \times 56 56×56 。
-
Stage2:
- Stage2 包含两个基本单元,每个单元由分组卷积和通道混洗构成。
- 输出大小为 28 × 28 28 \times 28 28×28 。
- 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 3 3 3 次 s t r i d e = 1 stride=1 stride=1 的基本单元。
- 输出通道数随着分组数的增加而增加。
-
Stage3:
- Stage3 包含两个基本单元,每个单元由分组卷积和通道混洗构成。
- 输出大小为 14 × 14 14 \times 14 14×14 。
- 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 7 7 7 次 s t r i d e = 1 stride=1 stride=1 的基本单元。
- 输出通道数随着分组数的增加而增加。
-
Stage4:
- Stage4包含两个基本单元,每个单元由分组卷积和通道混洗构成。
- 输出大小为 7 × 7 7 \times 7 7×7 。
- 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 3 3 3 次 s t r i d e = 1 stride=1 stride=1 的基本单元。
- 输出通道数随着分组数的增加而增加。
-
Global Pooling:
- 最后是一个全局平均池化层,将每个特征图压缩为单个值。
-
FC:
- 最终是一个全连接层,用于分类任务,输出类别数为 1000 1000 1000 。
计算复杂度
表的最后一列列出了不同分组数下的计算复杂度。可以看出,随着分组数的增加,计算复杂度逐渐降低。这是因为分组卷积将输入通道划分为若干组,每个组内的通道单独进行卷积运算,从而减少了计算量。
计算复杂度对比
上图展示了 ShuffleNet 与 ResNet 和 ResNeXt 在计算复杂度上的比较,并且包含了 ShuffleNet 的一个基本单元的结构图。
左边的图展示了一个标准的 ResNet 或 ResNeXt 的基本单元结构,包括三个主要部分:
1x1 Conv
:一个 1 × 1 1 \times 1 1×1 卷积层,用于减少输入通道的数量。3x3 DWConv
:一个 3 × 3 3 \times 3 3×3 深度可分离卷积层,用于提取空间特征。1x1 Conv
:另一个 1 × 1 1 \times 1 1×1 卷积层,用于恢复通道数量到原始值。
右边的图展示了 ShuffleNet 的基本单元结构,也包含三个部分:
1x1 GConv
:一个 1 × 1 1 \times 1 1×1 分组卷积层,用于减少输入通道的数量。Channel Shuffle
:通道混洗操作,用于打破分组卷积导致的信息隔离。1x1 GConv
:另一个 1 × 1 1 × 1 1×1 分组卷积层,用于恢复通道数量到原始值。
右侧的文字部分指出,由于分组卷积和通道混洗的存在,ShuffleNet 可以在相同设置下比 ResNet 和 ResNeXt 拥有更低的计算复杂度。具体来说,在给定输入尺寸 c × h × w c \times h \times w c×h×w 和瓶颈通道数 m m m 的情况下,ResNet 需要 h w ( 2 c m + 9 m 2 ) hw(2cm + 9m^2) hw(2cm+9m2) 浮点运算次数(FLOPs),ResNeXt 需要 h w ( 2 c m + 9 m 2 / g ) hw(2cm + 9m^2/g) hw(2cm+9m2/g) FLOPs,而ShuffleNet 只需要 h w ( 2 c m / g + 9 m ) hw(2cm/g + 9m) hw(2cm/g+9m) FLOPs,其中 g g g 表示分组数。
下方的公式进一步说明了这一点:
- ResNet 的计算复杂度为 h w ( 2 c m + 9 m 2 ) hw(2cm + 9m^2) hw(2cm+9m2)
- ResNeXt 的计算复杂度为 h w ( 2 c m + 9 m 2 / g ) hw(2cm + 9m^2/g) hw(2cm+9m2/g)
- ShuffleNet 的计算复杂度为 h w ( 2 c m / g + 9 m ) hw(2cm/g + 9m) hw(2cm/g+9m)
因此,ShuffleNet 通过分组卷积和通道混洗有效地降低了计算复杂度,使其更适合于计算资源有限的环境,如移动设备和嵌入式系统。
四. ShuffleNet v1 代码实现
开发环境配置说明:本项目使用 Python 3.6.13 和 PyTorch 1.10.2 构建,适用于CPU环境。
- model.py:定义网络模型
- train.py:加载数据集并训练,计算 loss 和 accuracy,保存训练好的网络参数
- predict.py:用自己的数据集进行分类测试
- utils.py:依赖脚本
- my_dataset.py:依赖脚本
通道混洗代码及其示意图如下所示:
- 通道混洗代码
# 定义channel_shuffle
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
# 获取输入x的[B, C, H, W]
batch_size, num_channels, height, width = x.size()
# 获取每个组的channel
channel_per_group = num_channels // groups
# reshape
# [B, C, H, W] -> [B, G, C, H, W]
x = x.view(batch_size, groups, channel_per_group, height, width)
# 调换维度1和维度2 -> [G, B, C, H, W]
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
- 通道混洗示意图
- model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
"""3x3 convolution with padding
"""
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def conv1x1(in_channels, out_channels, groups=1):
"""1x1 convolution with padding
- Normal pointwise convolution When groups == 1
- Grouped pointwise convolution when groups > 1
"""
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups # groups是分的组数
# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)
# transpose
# - contiguous() required if transpose() is used before view().
# See https://github.com/pytorch/pytorch/issues/764
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class ShuffleUnit(nn.Module):
def __init__(self, in_channels, out_channels, groups=3,
grouped_conv=True, combine='add'):
super(ShuffleUnit, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.grouped_conv = grouped_conv
self.combine = combine
self.groups = groups
self.bottleneck_channels = self.out_channels // 4
# define the type of ShuffleUnit
if self.combine == 'add':
# ShuffleUnit Figure 2b
self.depthwise_stride = 1
self._combine_func = self._add
elif self.combine == 'concat':
# ShuffleUnit Figure 2c
self.depthwise_stride = 2
self._combine_func = self._concat
# ensure output of concat has the same channels as
# original output channels.
self.out_channels -= self.in_channels
else:
raise ValueError("Cannot combine tensors with \"{}\"" \
"Only \"add\" and \"concat\" are" \
"supported".format(self.combine))
# Use a 1x1 grouped or non-grouped convolution to reduce input channels
# to bottleneck channels, as in a ResNet bottleneck module.
# NOTE: Do not use group convolution for the first conv1x1 in Stage 2.
self.first_1x1_groups = self.groups if grouped_conv else 1
self.g_conv_1x1_compress = self._make_grouped_conv1x1(
self.in_channels,
self.bottleneck_channels,
self.first_1x1_groups,
batch_norm=True,
relu=True
)
# 3x3 depthwise convolution followed by batch normalization
self.depthwise_conv3x3 = conv3x3(
self.bottleneck_channels, self.bottleneck_channels,
stride=self.depthwise_stride, groups=self.bottleneck_channels)
self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
# Use 1x1 grouped convolution to expand from
# bottleneck_channels to out_channels
self.g_conv_1x1_expand = self._make_grouped_conv1x1(
self.bottleneck_channels,
self.out_channels,
self.groups,
batch_norm=True,
relu=False
)
@staticmethod
def _add(x, out):
# residual connection
return x + out
@staticmethod
def _concat(x, out):
# concatenate along channel axis
return torch.cat((x, out), 1)
def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
batch_norm=True, relu=False):
modules = OrderedDict()
conv = conv1x1(in_channels, out_channels, groups=groups)
modules['conv1x1'] = conv
if batch_norm:
modules['batch_norm'] = nn.BatchNorm2d(out_channels)
if relu:
modules['relu'] = nn.ReLU()
if len(modules) > 1:
return nn.Sequential(modules)
else:
return conv
def forward(self, x):
# save for combining later with output
residual = x
if self.combine == 'concat':
residual = F.avg_pool2d(residual, kernel_size=3,
stride=2, padding=1)
out = self.g_conv_1x1_compress(x)
out = channel_shuffle(out, self.groups)
out = self.depthwise_conv3x3(out)
out = self.bn_after_depthwise(out)
out = self.g_conv_1x1_expand(out)
out = self._combine_func(residual, out)
return F.relu(out)
class ShuffleNet(nn.Module):
"""ShuffleNet implementation.
"""
def __init__(self, groups=3, in_channels=3, num_classes=1000):
"""ShuffleNet constructor.
Arguments:
groups (int, optional): number of groups to be used in grouped
1x1 convolutions in each ShuffleUnit. Default is 3 for best
performance according to original paper.
in_channels (int, optional): number of channels in the input tensor.
Default is 3 for RGB image inputs.
num_classes (int, optional): number of classes to predict. Default
is 1000 for ImageNet.
"""
super(ShuffleNet, self).__init__()
self.groups = groups
self.stage_repeats = [3, 7, 3]
self.in_channels = in_channels
self.num_classes = num_classes
# index 0 is invalid and should never be called.
# only used for indexing convenience.
if groups == 1:
self.stage_out_channels = [-1, 24, 144, 288, 567]
elif groups == 2:
self.stage_out_channels = [-1, 24, 200, 400, 800]
elif groups == 3:
self.stage_out_channels = [-1, 24, 240, 480, 960]
elif groups == 4:
self.stage_out_channels = [-1, 24, 272, 544, 1088]
elif groups == 8:
self.stage_out_channels = [-1, 24, 384, 768, 1536]
else:
raise ValueError(
"""{} groups is not supported for
1x1 Grouped Convolutions""".format(groups))
# Stage 1 always has 24 output channels
self.conv1 = conv3x3(self.in_channels,
self.stage_out_channels[1], # stage 1
stride=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Stage 2
self.stage2 = self._make_stage(2)
# Stage 3
self.stage3 = self._make_stage(3)
# Stage 4
self.stage4 = self._make_stage(4)
# Global pooling:
# Undefined as PyTorch's functional API can be used for on-the-fly
# shape inference if input size is not ImageNet's 224x224
# Fully-connected classification layer
num_inputs = self.stage_out_channels[-1]
self.fc = nn.Linear(num_inputs, self.num_classes)
self.init_params()
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal(m.weight, mode='fan_out')
if m.bias is not None:
init.constant(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant(m.weight, 1)
init.constant(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal(m.weight, std=0.001)
if m.bias is not None:
init.constant(m.bias, 0)
def _make_stage(self, stage):
modules = OrderedDict()
stage_name = "ShuffleUnit_Stage{}".format(stage)
# First ShuffleUnit in the stage
# 1. non-grouped 1x1 convolution (i.e. pointwise convolution)
# is used in Stage 2. Group convolutions used everywhere else.
grouped_conv = stage > 2
# 2. concatenation unit is always used.
first_module = ShuffleUnit(
self.stage_out_channels[stage - 1],
self.stage_out_channels[stage],
groups=self.groups,
grouped_conv=grouped_conv,
combine='concat'
)
modules[stage_name + "_0"] = first_module
# add more ShuffleUnits depending on pre-defined number of repeats
for i in range(self.stage_repeats[stage - 2]):
name = stage_name + "_{}".format(i + 1)
module = ShuffleUnit(
self.stage_out_channels[stage],
self.stage_out_channels[stage],
groups=self.groups,
grouped_conv=True,
combine='add'
)
modules[name] = module
return nn.Sequential(modules)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
# global average pooling layer
x = F.avg_pool2d(x, x.data.size()[-2:])
# flatten for input to fully-connected layer
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def shufflenet(num_classes=1000):
model = ShuffleNet(groups=3, in_channels=3, num_classes=num_classes)
return model
- train.py
import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
from model import shufflenet
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(args)
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
tb_writer = SummaryWriter()
if os.path.exists("./weights") is False:
os.makedirs("./weights")
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
# 实例化训练数据集
train_dataset = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])
# 实例化验证数据集
val_dataset = MyDataSet(images_path=val_images_path,
images_class=val_images_label,
transform=data_transform["val"])
batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=nw,
collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=nw,
collate_fn=val_dataset.collate_fn)
# 如果存在预训练权重则载入
model = shufflenet(num_classes=args.num_classes).to(device)
if args.weights != "":
if os.path.exists(args.weights):
weights_dict = torch.load(args.weights, map_location=device)
load_weights_dict = {k: v for k, v in weights_dict.items()
if model.state_dict()[k].numel() == v.numel()}
print(model.load_state_dict(load_weights_dict, strict=False))
else:
raise FileNotFoundError("not found weights file: {}".format(args.weights))
# 是否冻结权重
if args.freeze_layers:
for name, para in model.named_parameters():
# 除最后的全连接层外,其他权重全部冻结
if "fc" not in name:
para.requires_grad_(False)
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
for epoch in range(args.epochs):
# train
mean_loss = train_one_epoch(model=model,
optimizer=optimizer,
data_loader=train_loader,
device=device,
epoch=epoch)
scheduler.step()
# validate
acc = evaluate(model=model,
data_loader=val_loader,
device=device)
print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
tags = ["loss", "accuracy", "learning_rate"]
tb_writer.add_scalar(tags[0], mean_loss, epoch)
tb_writer.add_scalar(tags[1], acc, epoch)
tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)
torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--lrf', type=float, default=0.1)
# 数据集所在根目录
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="E:/code/PyCharm_Projects/deep_learning/data_set/flower_data/flower_photos")
# 不使用预训练权重 default=''
parser.add_argument('--weights', type=str, default='',
help='initial weights path')
# 冻结除最后全连接层的所有权重 default=True
parser.add_argument('--freeze-layers', type=bool, default=False)
parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
opt = parser.parse_args()
main(opt)
- predict.py
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import shufflenet
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# load image
img_path = "郁金香.png"
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path)
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
with open(json_path, "r") as f:
class_indict = json.load(f)
# create model
model = shufflenet(num_classes=5).to(device)
# load model weights
model_weight_path = "./weights/model-0.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
predict[i].numpy()))
plt.show()
if __name__ == '__main__':
main()
- utils.py
import os
import sys
import json
import pickle
import random
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
def read_split_data(root: str, val_rate: float = 0.2):
random.seed(0) # 保证随机结果可复现
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
# 排序,保证各平台顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
train_images_path = [] # 存储训练集的所有图片路径
train_images_label = [] # 存储训练集图片对应索引信息
val_images_path = [] # 存储验证集的所有图片路径
val_images_label = [] # 存储验证集图片对应索引信息
every_class_num = [] # 存储每个类别的样本总数
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
# 遍历每个文件夹下的文件
for cla in flower_class:
cla_path = os.path.join(root, cla)
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
# 排序,保证各平台顺序一致
images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
every_class_num.append(len(images))
# 按比例随机采样验证样本
val_path = random.sample(images, k=int(len(images) * val_rate))
for img_path in images:
if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
val_images_path.append(img_path)
val_images_label.append(image_class)
else: # 否则存入训练集
train_images_path.append(img_path)
train_images_label.append(image_class)
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
assert len(train_images_path) > 0, "number of training images must greater than 0."
assert len(val_images_path) > 0, "number of validation images must greater than 0."
plot_image = False
if plot_image:
# 绘制每种类别个数柱状图
plt.bar(range(len(flower_class)), every_class_num, align='center')
# 将横坐标0,1,2,3,4替换为相应的类别名称
plt.xticks(range(len(flower_class)), flower_class)
# 在柱状图上添加数值标签
for i, v in enumerate(every_class_num):
plt.text(x=i, y=v + 5, s=str(v), ha='center')
# 设置x坐标
plt.xlabel('image class')
# 设置y坐标
plt.ylabel('number of images')
# 设置柱状图的标题
plt.title('flower class distribution')
plt.show()
return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader):
batch_size = data_loader.batch_size
plot_num = min(batch_size, 4)
json_path = './class_indices.json'
assert os.path.exists(json_path), json_path + " does not exist."
json_file = open(json_path, 'r')
class_indices = json.load(json_file)
for data in data_loader:
images, labels = data
for i in range(plot_num):
# [C, H, W] -> [H, W, C]
img = images[i].numpy().transpose(1, 2, 0)
# 反Normalize操作
img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
label = labels[i].item()
plt.subplot(1, plot_num, i+1)
plt.xlabel(class_indices[str(label)])
plt.xticks([]) # 去掉x轴的刻度
plt.yticks([]) # 去掉y轴的刻度
plt.imshow(img.astype('uint8'))
plt.show()
def write_pickle(list_info: list, file_name: str):
with open(file_name, 'wb') as f:
pickle.dump(list_info, f)
def read_pickle(file_name: str) -> list:
with open(file_name, 'rb') as f:
info_list = pickle.load(f)
return info_list
def train_one_epoch(model, optimizer, data_loader, device, epoch):
model.train()
loss_function = torch.nn.CrossEntropyLoss()
mean_loss = torch.zeros(1).to(device)
optimizer.zero_grad()
data_loader = tqdm(data_loader, file=sys.stdout)
for step, data in enumerate(data_loader):
images, labels = data
pred = model(images.to(device))
loss = loss_function(pred, labels.to(device))
loss.backward()
mean_loss = (mean_loss * step + loss.detach()) / (step + 1) # update mean losses
data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss)
sys.exit(1)
optimizer.step()
optimizer.zero_grad()
return mean_loss.item()
@torch.no_grad()
def evaluate(model, data_loader, device):
model.eval()
# 验证样本总个数
total_num = len(data_loader.dataset)
# 用于存储预测正确的样本个数
sum_num = torch.zeros(1).to(device)
data_loader = tqdm(data_loader, file=sys.stdout)
for step, data in enumerate(data_loader):
images, labels = data
pred = model(images.to(device))
pred = torch.max(pred, dim=1)[1]
sum_num += torch.eq(pred, labels.to(device)).sum()
return sum_num.item() / total_num
- my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
"""自定义数据集"""
def __init__(self, images_path: list, images_class: list, transform=None):
self.images_path = images_path
self.images_class = images_class
self.transform = transform
def __len__(self):
return len(self.images_path)
def __getitem__(self, item):
img = Image.open(self.images_path[item])
# RGB为彩色图片,L为灰度图片
if img.mode != 'RGB':
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
label = self.images_class[item]
if self.transform is not None:
img = self.transform(img)
return img, label
@staticmethod
def collate_fn(batch):
# 官方实现的default_collate可以参考
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0)
labels = torch.as_tensor(labels)
return images, labels