经典CNN模型(十):ShuffleNetV1(PyTorch详细注释版)

一. ShuffleNet v1 神经网络介绍

ShuffleNet V1 是一种专为移动设备设计的高效卷积神经网络(CNN),旨在解决计算资源有限的问题。它由阿里巴巴达摩院的研究人员提出,并在2018年的论文《ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices》中首次被介绍。以下是ShuffleNet V1的一些关键特点和组成部分:

关键特点

  1. Group Convolution (分组卷积):

    • ShuffleNet V1 使用分组卷积来减少计算成本。分组卷积将输入通道分成几个组,并在每个组内独立执行卷积操作。这类似于Inception模块中的并行分支,但更加高效。
  2. Channel Shuffle (通道重排):

    • 为了克服分组卷积导致的信息隔离问题,ShuffleNet V1 引入了通道重排操作。这个操作使得不同组之间的信息能够混合,从而提高了模型的表达能力。
  3. Depthwise Separable Convolutions (深度可分离卷积):

    • 尽管ShuffleNet V1 主要依赖于分组卷积,但它也使用深度可分离卷积作为其构建块之一,进一步减少计算量。
  4. Efficiency and Accuracy Trade-off (效率与准确性的权衡):

    • ShuffleNet V1 通过调整分组的数量来控制模型的复杂度,从而在计算效率和分类精度之间找到最佳平衡点。

架构组成

ShuffleNet V1 的基本单元称为 Shuffle 单元(Shuffle Unit),它通常包含两个子单元:基本的 Shuffle 单元和下采样 Shuffle 单元。

  1. Basic Shuffle Unit:

    • 该单元用于网络中的大部分部分,其中包含两个主要步骤:
      • 第一步使用逐点卷积(1x1卷积)进行通道降维。
      • 第二步包括两个并行路径:
        • 分组卷积(GConv):对输入进行分组卷积。
        • 直接传递:将一部分输入通道直接传递到下一个层。
      • 最后,通过通道重排(Channel Shuffle)将这两个路径的输出合并起来。
  2. Downsample Shuffle Unit:

    • 当需要减小空间尺寸时,会使用下采样 Shuffle 单元。这种单元通常包含一个额外的深度可分离卷积层来进行下采样操作。

应用场景

ShuffleNet V1 特别适合于移动设备和嵌入式系统,因为它在保证一定精度的前提下大大降低了计算需求。它适用于各种计算机视觉任务,包括但不限于图像分类、目标检测和语义分割。

总结

ShuffleNet V1 是一种轻量级的卷积神经网络架构,通过使用分组卷积和通道重排技术,在保持较高分类准确率的同时,极大地降低了模型的计算复杂度和内存占用。这种设计使其成为在资源受限设备上运行的理想选择。

二. ShuffleNet v1 神经网络细节

下面是关于ShuffleNet V1神经网络中三个关键概念的详细说明:分组卷积(Group Convolution)、通道混洗(Channel Shuffle)以及ShuffleNet基础单元(ShuffleNet Unit)。

1. Group Convolution (分组卷积)

定义

分组卷积是一种高效的卷积运算方式,它将输入通道分成多个组,并在每个组内独立进行卷积运算。这种做法减少了计算量和参数数量,尤其是在处理高分辨率图像时效果显著。

在这里插入图片描述

工作原理
  • 输入分组:
    • 假设输入张量具有 C C C 个通道,我们将这些通道分成 G G G 组,每组有 C G \frac{C}{G} GC 个通道。
  • 独立卷积:
    • 每个组内的通道只与本组的滤波器进行卷积,而不是整个输入张量的所有通道。
  • 输出拼接:
    • 每个组的输出结果被拼接在一起形成最终的输出张量。

在这里插入图片描述

优势
  • 计算效率:
    • 分组卷积显著减少了计算量,因为每个滤波器只需要处理输入的一部分。
  • 参数减少:
    • 由于每个滤波器只作用于一部分输入通道,因此模型的参数数量也随之减少。

2. Channel Shuffle (通道混洗)

定义

通道混洗是一种特殊的操作,用于在分组卷积之后混合不同组之间的信息。它的目的是打破分组卷积导致的信息隔离,提高模型的表示能力。

在这里插入图片描述
图(a)展示了只有分组卷积的情况;不同颜色代表不同的分组,每个分组的输入的没有掺杂其他分组的特征,这就相当于各自管各自的,导致了分组之间信息的闭塞。如果允许每个分组卷积获取不同组的特征,如图(b)所示,将 GConv1 所有分组的输出特征 Feature 都根据组数均匀分发,作为 GConv2 每个分组的输入,那么输出(Output)和输入(Input)通道就完全相关了。这种混洗操作可以通过图(c)的通道混洗高效优雅地实现。图(c)展示了既使用分组卷积又使用通道混洗时的情况,它打破了分组卷积带来的信息隔离。箭头表示通道间的交换过程,将每个组内的通道重新排列。通过通道混洗,ShuffleNet V1能够更好地混合来自不同组的信息,从而增强模型的表示能力和泛化能力。

工作原理
  • 输入分组:
    • 输入张量仍然按照 G G G 个组进行划分。
  • 混洗操作:
    • 混洗操作首先将每个组内的通道重新排列,然后将所有组的通道重新组合。
    • 具体来说,假设每个组有 C G \frac{C}{G} GC 个通道,混洗操作将每个通道从第 i i i 组移动到第 i i i 个位置的新组,其中 i i i 是通道的索引除以 C G \frac{C}{G} GC 的余数。
  • 输出:
    • 经过混洗操作后的输出张量保留了原始张量的形状,但通道间的信息得到了更好的混合。
优势
  • 信息流动:
    • 通道混洗促进了不同组之间的信息流动,有助于提高模型的泛化能力。
  • 计算效率:
    • 通道混洗操作本身计算非常简单,几乎不增加额外的计算负担。

3. ShuffleNet Unit (ShuffleNet基础单元)

定义

ShuffleNet的基础单元(ShuffleNet Unit)是网络的基本构建块,它结合了分组卷积和通道混洗操作,以实现高效的计算。

在这里插入图片描述

图(a)是一个典型的带有深度可分离卷积残差结构,ShuffleNet_V1 在此基础上设计出 ShuffleNet 单元。图(b)则是 stride=1 时的 ShuffleNet 单元,使用 1x1 分组卷积代替密集的 1x1 卷积,降低原 1x1 卷积的开销,同时加入 Channel Shuffle 实现跨通道信息交流。图(c)则是 stride=2 时的 ShuffleNet 单元,因为需要对特征图进行下采样,因此在图(b)结构基础上对残差连接分支采用 stride=2 的 3x3 全局平局池化,然后将主干输出特征和分支特征进行 concat,而不再是 add,大大的降低计算量与参数大小。

优势
  • 计算效率:
    • ShuffleNet Unit 利用分组卷积和通道混洗操作,显著减少了计算量。
  • 表示能力:
    • 通过通道混洗增强了信息流动,提高了模型的表示能力。
  • 灵活性:
    • 可以通过调整分组数 ( G ) 来控制模型的复杂度。

总结

ShuffleNet V1 通过使用分组卷积来减少计算量,通过通道混洗来促进不同组之间的信息流动,以及通过ShuffleNet基础单元来构建整个网络结构。这些技术共同作用,使得ShuffleNet V1 能够在资源受限的设备上高效运行,同时保持较高的分类准确率。

三. ShuffleNet v1 神经网络结构

ShuffleNet V1 是一种专门为移动设备设计的高效卷积神经网络,它采用了分组卷积和通道混洗操作来降低计算复杂度。下面是对ShuffleNet V1结构的详细描述:

在这里插入图片描述

层次结构

ShuffleNet V1 由一系列阶段(Stage)组成,每个阶段都包含多个基本单元(Unit)。表中列出了各个阶段及其对应的输出大小、卷积核大小(KSize)、步长(Stride)和重复次数(Repeat)。此外,还给出了不同分组数(g groups)下的输出通道数。

阶段描述

  1. Conv1:

    • 这是第一个卷积层,用于初始化输入图像。它采用 3 × 3 3 \times 3 3×3 的卷积核,步长为 2 2 2 ,输出大小为 112 × 112 112 \times 112 112×112
    • 输出通道数为 24 24 24 ,对于所有分组数都是相同的。
  2. MaxPool:

    • 接着是一个最大池化层,同样采用 3 × 3 3 \times 3 3×3 的核,步长为 2 2 2 ,输出大小为 56 × 56 56 \times 56 56×56
  3. Stage2:

    • Stage2 包含两个基本单元,每个单元由分组卷积和通道混洗构成。
    • 输出大小为 28 × 28 28 \times 28 28×28
    • 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 3 3 3 s t r i d e = 1 stride=1 stride=1 的基本单元。
    • 输出通道数随着分组数的增加而增加。
  4. Stage3:

    • Stage3 包含两个基本单元,每个单元由分组卷积和通道混洗构成。
    • 输出大小为 14 × 14 14 \times 14 14×14
    • 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 7 7 7 s t r i d e = 1 stride=1 stride=1 的基本单元。
    • 输出通道数随着分组数的增加而增加。
  5. Stage4:

    • Stage4包含两个基本单元,每个单元由分组卷积和通道混洗构成。
    • 输出大小为 7 × 7 7 \times 7 7×7
    • 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 3 3 3 s t r i d e = 1 stride=1 stride=1 的基本单元。
    • 输出通道数随着分组数的增加而增加。
  6. Global Pooling:

    • 最后是一个全局平均池化层,将每个特征图压缩为单个值。
  7. FC:

    • 最终是一个全连接层,用于分类任务,输出类别数为 1000 1000 1000

计算复杂度

表的最后一列列出了不同分组数下的计算复杂度。可以看出,随着分组数的增加,计算复杂度逐渐降低。这是因为分组卷积将输入通道划分为若干组,每个组内的通道单独进行卷积运算,从而减少了计算量。

计算复杂度对比

在这里插入图片描述

上图展示了 ShuffleNet 与 ResNet 和 ResNeXt 在计算复杂度上的比较,并且包含了 ShuffleNet 的一个基本单元的结构图。

左边的图展示了一个标准的 ResNet 或 ResNeXt 的基本单元结构,包括三个主要部分:

  1. 1x1 Conv:一个 1 × 1 1 \times 1 1×1 卷积层,用于减少输入通道的数量。
  2. 3x3 DWConv:一个 3 × 3 3 \times 3 3×3 深度可分离卷积层,用于提取空间特征。
  3. 1x1 Conv:另一个 1 × 1 1 \times 1 1×1 卷积层,用于恢复通道数量到原始值。

右边的图展示了 ShuffleNet 的基本单元结构,也包含三个部分:

  1. 1x1 GConv:一个 1 × 1 1 \times 1 1×1 分组卷积层,用于减少输入通道的数量。
  2. Channel Shuffle:通道混洗操作,用于打破分组卷积导致的信息隔离。
  3. 1x1 GConv:另一个 1 × 1 1 × 1 1×1 分组卷积层,用于恢复通道数量到原始值。

右侧的文字部分指出,由于分组卷积和通道混洗的存在,ShuffleNet 可以在相同设置下比 ResNet 和 ResNeXt 拥有更低的计算复杂度。具体来说,在给定输入尺寸 c × h × w c \times h \times w c×h×w 和瓶颈通道数 m m m 的情况下,ResNet 需要 h w ( 2 c m + 9 m 2 ) hw(2cm + 9m^2) hw(2cm+9m2) 浮点运算次数(FLOPs),ResNeXt 需要 h w ( 2 c m + 9 m 2 / g ) hw(2cm + 9m^2/g) hw(2cm+9m2/g) FLOPs,而ShuffleNet 只需要 h w ( 2 c m / g + 9 m ) hw(2cm/g + 9m) hw(2cm/g+9m) FLOPs,其中 g g g 表示分组数。

下方的公式进一步说明了这一点:

  • ResNet 的计算复杂度为 h w ( 2 c m + 9 m 2 ) hw(2cm + 9m^2) hw(2cm+9m2)
  • ResNeXt 的计算复杂度为 h w ( 2 c m + 9 m 2 / g ) hw(2cm + 9m^2/g) hw(2cm+9m2/g)
  • ShuffleNet 的计算复杂度为 h w ( 2 c m / g + 9 m ) hw(2cm/g + 9m) hw(2cm/g+9m)

因此,ShuffleNet 通过分组卷积和通道混洗有效地降低了计算复杂度,使其更适合于计算资源有限的环境,如移动设备和嵌入式系统。

四. ShuffleNet v1 代码实现

开发环境配置说明:本项目使用 Python 3.6.13 和 PyTorch 1.10.2 构建,适用于CPU环境。

  • model.py:定义网络模型
  • train.py:加载数据集并训练,计算 loss 和 accuracy,保存训练好的网络参数
  • predict.py:用自己的数据集进行分类测试
  • utils.py:依赖脚本
  • my_dataset.py:依赖脚本

通道混洗代码及其示意图如下所示:

  • 通道混洗代码
#   定义channel_shuffle
def channel_shuffle(x: Tensor, groups: int) -> Tensor:

    #   获取输入x的[B, C, H, W]
    batch_size, num_channels, height, width = x.size()
    #   获取每个组的channel
    channel_per_group = num_channels // groups

    #   reshape
    #   [B, C, H, W] -> [B, G, C, H, W]
    x = x.view(batch_size, groups, channel_per_group, height, width)

    #   调换维度1和维度2 -> [G, B, C, H, W]
    x = torch.transpose(x, 1, 2).contiguous()

    #   flatten
    x = x.view(batch_size, -1, height, width)

    return x
  • 通道混洗示意图

在这里插入图片描述

  1. model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init


def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """3x3 convolution with padding
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def conv1x1(in_channels, out_channels, groups=1):
    """1x1 convolution with padding
    - Normal pointwise convolution When groups == 1
    - Grouped pointwise convolution when groups > 1
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)


def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups  # groups是分的组数

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class ShuffleUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=3,
                 grouped_conv=True, combine='add'):

        super(ShuffleUnit, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grouped_conv = grouped_conv
        self.combine = combine
        self.groups = groups
        self.bottleneck_channels = self.out_channels // 4

        # define the type of ShuffleUnit
        if self.combine == 'add':
            # ShuffleUnit Figure 2b
            self.depthwise_stride = 1
            self._combine_func = self._add
        elif self.combine == 'concat':
            # ShuffleUnit Figure 2c
            self.depthwise_stride = 2
            self._combine_func = self._concat

            # ensure output of concat has the same channels as
            # original output channels.
            self.out_channels -= self.in_channels
        else:
            raise ValueError("Cannot combine tensors with \"{}\"" \
                             "Only \"add\" and \"concat\" are" \
                             "supported".format(self.combine))

        # Use a 1x1 grouped or non-grouped convolution to reduce input channels
        # to bottleneck channels, as in a ResNet bottleneck module.
        # NOTE: Do not use group convolution for the first conv1x1 in Stage 2.
        self.first_1x1_groups = self.groups if grouped_conv else 1

        self.g_conv_1x1_compress = self._make_grouped_conv1x1(
            self.in_channels,
            self.bottleneck_channels,
            self.first_1x1_groups,
            batch_norm=True,
            relu=True
        )

        # 3x3 depthwise convolution followed by batch normalization
        self.depthwise_conv3x3 = conv3x3(
            self.bottleneck_channels, self.bottleneck_channels,
            stride=self.depthwise_stride, groups=self.bottleneck_channels)
        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)

        # Use 1x1 grouped convolution to expand from
        # bottleneck_channels to out_channels
        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
            self.bottleneck_channels,
            self.out_channels,
            self.groups,
            batch_norm=True,
            relu=False
        )

    @staticmethod
    def _add(x, out):
        # residual connection
        return x + out

    @staticmethod
    def _concat(x, out):
        # concatenate along channel axis
        return torch.cat((x, out), 1)

    def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
                              batch_norm=True, relu=False):

        modules = OrderedDict()

        conv = conv1x1(in_channels, out_channels, groups=groups)
        modules['conv1x1'] = conv

        if batch_norm:
            modules['batch_norm'] = nn.BatchNorm2d(out_channels)
        if relu:
            modules['relu'] = nn.ReLU()
        if len(modules) > 1:
            return nn.Sequential(modules)
        else:
            return conv

    def forward(self, x):
        # save for combining later with output
        residual = x

        if self.combine == 'concat':
            residual = F.avg_pool2d(residual, kernel_size=3,
                                    stride=2, padding=1)

        out = self.g_conv_1x1_compress(x)
        out = channel_shuffle(out, self.groups)
        out = self.depthwise_conv3x3(out)
        out = self.bn_after_depthwise(out)
        out = self.g_conv_1x1_expand(out)

        out = self._combine_func(residual, out)
        return F.relu(out)


class ShuffleNet(nn.Module):
    """ShuffleNet implementation.
    """

    def __init__(self, groups=3, in_channels=3, num_classes=1000):
        """ShuffleNet constructor.

        Arguments:
            groups (int, optional): number of groups to be used in grouped
                1x1 convolutions in each ShuffleUnit. Default is 3 for best
                performance according to original paper.
            in_channels (int, optional): number of channels in the input tensor.
                Default is 3 for RGB image inputs.
            num_classes (int, optional): number of classes to predict. Default
                is 1000 for ImageNet.

        """
        super(ShuffleNet, self).__init__()

        self.groups = groups
        self.stage_repeats = [3, 7, 3]
        self.in_channels = in_channels
        self.num_classes = num_classes

        # index 0 is invalid and should never be called.
        # only used for indexing convenience.
        if groups == 1:
            self.stage_out_channels = [-1, 24, 144, 288, 567]
        elif groups == 2:
            self.stage_out_channels = [-1, 24, 200, 400, 800]
        elif groups == 3:
            self.stage_out_channels = [-1, 24, 240, 480, 960]
        elif groups == 4:
            self.stage_out_channels = [-1, 24, 272, 544, 1088]
        elif groups == 8:
            self.stage_out_channels = [-1, 24, 384, 768, 1536]
        else:
            raise ValueError(
                """{} groups is not supported for
                   1x1 Grouped Convolutions""".format(groups))

        # Stage 1 always has 24 output channels
        self.conv1 = conv3x3(self.in_channels,
                             self.stage_out_channels[1],  # stage 1
                             stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 2
        self.stage2 = self._make_stage(2)
        # Stage 3
        self.stage3 = self._make_stage(3)
        # Stage 4
        self.stage4 = self._make_stage(4)

        # Global pooling:
        # Undefined as PyTorch's functional API can be used for on-the-fly
        # shape inference if input size is not ImageNet's 224x224

        # Fully-connected classification layer
        num_inputs = self.stage_out_channels[-1]
        self.fc = nn.Linear(num_inputs, self.num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant(m.weight, 1)
                init.constant(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant(m.bias, 0)

    def _make_stage(self, stage):
        modules = OrderedDict()
        stage_name = "ShuffleUnit_Stage{}".format(stage)

        # First ShuffleUnit in the stage
        # 1. non-grouped 1x1 convolution (i.e. pointwise convolution)
        #   is used in Stage 2. Group convolutions used everywhere else.
        grouped_conv = stage > 2

        # 2. concatenation unit is always used.
        first_module = ShuffleUnit(
            self.stage_out_channels[stage - 1],
            self.stage_out_channels[stage],
            groups=self.groups,
            grouped_conv=grouped_conv,
            combine='concat'
        )
        modules[stage_name + "_0"] = first_module

        # add more ShuffleUnits depending on pre-defined number of repeats
        for i in range(self.stage_repeats[stage - 2]):
            name = stage_name + "_{}".format(i + 1)
            module = ShuffleUnit(
                self.stage_out_channels[stage],
                self.stage_out_channels[stage],
                groups=self.groups,
                grouped_conv=True,
                combine='add'
            )
            modules[name] = module

        return nn.Sequential(modules)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)

        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # global average pooling layer
        x = F.avg_pool2d(x, x.data.size()[-2:])

        # flatten for input to fully-connected layer
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x



def shufflenet(num_classes=1000):
    model = ShuffleNet(groups=3, in_channels=3, num_classes=num_classes)
    return model
  1. train.py
import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import shufflenet
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # 如果存在预训练权重则载入
    model = shufflenet(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后的全连接层外,其他权重全部冻结
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)

        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.1)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="E:/code/PyCharm_Projects/deep_learning/data_set/flower_data/flower_photos")

    # 不使用预训练权重 default=''
    parser.add_argument('--weights', type=str, default='',
                        help='initial weights path')
    # 冻结除最后全连接层的所有权重 default=True
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
  1. predict.py
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import shufflenet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "郁金香.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = shufflenet(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-0.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
  1. utils.py
import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证各平台顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 排序,保证各平台顺序一致
        images.sort()
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    plot_image = False
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # 反Normalize操作
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # 去掉x轴的刻度
            plt.yticks([])  # 去掉y轴的刻度
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        images, labels = data

        pred = model(images.to(device))

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  # update mean losses

        data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return mean_loss.item()


@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()

    # 验证样本总个数
    total_num = len(data_loader.dataset)

    # 用于存储预测正确的样本个数
    sum_num = torch.zeros(1).to(device)

    data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    return sum_num.item() / total_num
  1. my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

五. 参考内容

  1. 李沐. (2019). 动手学深度学习. 北京: 人民邮电出版社. [ISBN: 978-7-115-51364-9]
  2. 霹雳吧啦Wz. (202X). 深度学习实战系列 [在线视频]. 哔哩哔哩. URL
  3. PyTorch. (n.d.). PyTorch官方文档和案例 [在线资源]. URL
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值