经典CNN模型(十一):ShuffleNetV2(PyTorch详细注释版)

一. ShuffleNet V2 神经网络介绍

ShuffleNet V2 是对 ShuffleNet 的改进版本,旨在解决其前代的一些问题,例如低分辨率和通道混洗效率低下等。以下是 ShuffleNet V2 的一些关键特点:

ShuffleNet V2 特点

  1. 优化的分组卷积:
    • ShuffleNet V2 使用了一种称为“channel split”的技术,该技术将输入通道分成两半,分别进行不同的处理,然后合并结果以获得更好的性能。
  2. 自适应分组卷积:
    • ShuffleNet V2 根据输入数据动态调整分组数量,以实现更高的效率。
  3. 多尺度特征融合:
    • ShuffleNet V2 引入了多尺度特征融合模块,以更好地捕捉不同尺度的特征。
  4. 通道剪枝:
    • ShuffleNet V2 应用通道剪枝策略来进一步减少计算复杂度,同时保持准确性。

ShuffleNet V2 结构

ShuffleNet V2 仍然基于分组卷积和通道混洗,但引入了一些新的设计决策来改善性能和效率。例如,它使用了自适应分组卷积,根据输入数据自动选择合适的分组数量。此外,它还增加了多尺度特征融合模块,以捕获不同尺度的特征,这对于识别复杂的对象和场景非常有用。

ShuffleNet V2 单元

ShuffleNet V2 单元包含以下组件:

  1. 通道分割:
    • 输入通道被分割成两部分,每部分独立处理。
  2. 分组卷积:
    • 分割后的通道各自执行分组卷积。
  3. 通道混洗:
    • 混洗操作用于打破分组卷积中的信息隔离。
  4. 多尺度特征融合:
    • 不同尺度的特征被融合在一起,以增强模型的表现力。
  5. 通道剪枝:
    • 通过剪枝策略减少不必要的通道,以降低计算复杂度。

总结

ShuffleNet V2 在保留了原版 ShuffleNet 的高效性的同时,通过引入自适应分组卷积、多尺度特征融合和通道剪枝等方法,提高了模型的性能和灵活性。这种改进使得 ShuffleNet V2 更适用于资源受限的设备,如手机和嵌入式系统。

二. ShuffleNet V2 神经网络细节

ShuffleNet V2的设计考虑了多个因素以提高计算效率和性能。以下是更详细的介绍:

1. ShuffleNet V2 设计准则

在这里插入图片描述

G1:内存访问成本最小化

ShuffleNet V2 试图最小化内存访问成本(MAC)。在深度可分离卷积中, 1 × 1 1×1 1×1 卷积的 FLOPs 为 B = h w c 1 c 2 B = hwc_{1}c_{2} B=hwc1c2 ,内存访问成本为: B = h w ( c 1 + c 2 ) + c 1 c 2 B = hw(c_{1}+c_{2})+c_{1}c_{2} B=hw(c1+c2)+c1c2 根据均值不等式,MAC 至少为: 2 h w B + B h w 2\sqrt{hwB}+\frac{B}{hw} 2hwB +hwB 只有当 c 1 = c 2 c_{1} = c_{2} c1=c2 时,MAC 才能取到最小值。
在这里插入图片描述

G2:避免过多使用分组卷积

对于分组卷积,FLOPs 为 B = h w c 1 c 2 / g B = hwc_{1}c_{2}/g B=hwc1c2/g ,MAC 为: h w ( c 1 + c 2 ) + c 1 c 2 / g hw(c_{1}+c_{2})+c_{1}c_{2}/g hw(c1+c2)+c1c2/g 如果固定输入 c 1 × h × w c_{1}\times h \times w c1×h×w B B B ,则 MAC 为: h w c 1 + B g / c 1 + B / h w hwc_{1}+Bg/c_{1}+B/hw hwc1+Bg/c1+B/hw随着分组数 g g g 的增加,MAC 也会增加。
在这里插入图片描述

G3:避免网络碎片化

ShuffleNet V2 避免使用多路结构,因为它会导致网络碎片化,从而减慢速度。

G4:重视元素级操作

ShuffleNet V2 注意到 ReLU、TensorAdd 和 BiasAdd 等元素级操作虽然 FLOPs 少,但 MAC 大。实验表明,移除残差网络中的 ReLU 和短接可以提高 20% 的速度。

因此,基于上述分析,作者得到了4条指导准则

  1. 使用“平衡”卷积层,即输入与输出通道相同;
  2. 谨慎使用分组卷积并注意分组数;
  3. 减少碎片化的操作;
  4. 减少元素级的操作。

2. ShuffleNet V2 Block

下图中的 a 与 b 是 ShuffleNet V1 的 architecture,c 与 d 是 ShuffleNet V2 的 architecture。仔细观察我们可以发现ShuffleNet V1中到处违背了4条设计原则:

  1. 使用了 bottleneck layer,使得输入输出通道数不同,违背了 G1 原则;
  2. 大量使用 1 × 1 1 × 1 1×1 卷积,违背了 G2 原则;
  3. 使用了过多的 group ,违背了 G3 原则;
  4. shortcut 中存在了大量的元素级 add 运算,违背了 G4 原则。

在这里插入图片描述
ShuffleNet V2 是对ShuffleNet V1 的改进版本,旨在解决前者的几个问题。以下是针对所提到的几点的详细分析:

  1. 通道拆分(Channel Split):
    • ShuffleNet V2 引入了一个新的操作——通道拆分,如图 C 和 D 所示。这将输入通道分成两个分支,其中一个分支执行恒等映射,保持输入和输出通道数不变,符合 G3 原则,即避免网络碎片化。另一个分支执行多层卷积,以确保输入和输出通道数相等,符合 G1 原则,即最小化内存访问成本。
  2. 取消分组卷积:
    • ShuffleNet V2 放弃了ShuffleNet V1 中使用的分组卷积,特别是 1x1 分组卷积,以减少内存访问成本,符合 G2 原则。
  3. 使用concatenate代替TensorAdd:
    • ShuffleNet V2 使用 concatenate 操作而不是 TensorAdd 来合并两个分支的输出,这是因为 concatenate 操作具有更低的计算复杂度,符合 G4 原则。

通过这些改进,ShuffleNet V2 在保持计算效率的同时提高了准确率,并减少了内存访问成本。此外,它还避免了网络碎片化,降低了计算复杂度,从而提高了整体性能。

三. ShuffleNet V2 神经网络结构

ShuffleNet V2 是一种专门为移动设备优化的高效卷积神经网络,它利用通道拆分和通道级联操作来提高计算效率。下面是ShuffleNet V2 结构的详细概述:

在这里插入图片描述

层次结构

ShuffleNet V2 由一系列阶段(Stage)组成,每个阶段都包含多个基本单元(Basic Unit)。表中列出了各个阶段及其对应的输出大小、卷积核大小(KSize)、步长(Stride)和重复次数(Repeat)。此外,还给出了不同缩放因子(Scaling Factor)下的输出通道数。

阶段描述

  1. Conv1:

    • 这是第一个卷积层,用于初始化输入图像。它采用 3 × 3 3 \times 3 3×3 的卷积核,步长为 2 2 2 ,输出大小为 56 × 56 56 \times 56 56×56
    • 输出通道数为 24 24 24 ,对于所有缩放因子都是相同的。
  2. MaxPool:

    • 接着是一个最大池化层,同样采用 3 × 3 3 \times 3 3×3 的核,步长为 2 2 2 ,输出大小为 56 × 56 56 \times 56 56×56
  3. Stage2:

    • Stage2 包含两个基本单元,每个单元由分组卷积和通道混洗构成。
    • 输出大小为 28 × 28 28 \times 28 28×28
    • 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 3 3 3 s t r i d e = 1 stride=1 stride=1 的基本单元。
    • 输出通道数随着分组数的增加而增加。
  4. Stage3:

    • Stage3 包含两个基本单元,每个单元由分组卷积和通道混洗构成。
    • 输出大小为 14 × 14 14 \times 14 14×14
    • 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 7 7 7 s t r i d e = 1 stride=1 stride=1 的基本单元。
    • 输出通道数随着分组数的增加而增加。
  5. Stage4:

    • Stage4包含两个基本单元,每个单元由分组卷积和通道混洗构成。
    • 输出大小为 7 × 7 7 \times 7 7×7
    • 首先采用 s t r i d e = 2 stride=2 stride=2 的基本单元,紧接着重复 3 3 3 s t r i d e = 1 stride=1 stride=1 的基本单元。
    • 输出通道数随着分组数的增加而增加。
  6. Conv5:

    • Conv5 是一个全局池化层,使用 1 × 1 1 \times 1 1×1 的卷积核进行全局池化操作。
    • 所有缩放因子的输出通道数均为 1024 1024 1024
  7. GlobalPool and FC:

    • 全局池化层将 7 × 7 7 \times 7 7×7 的特征图转换为一个向量,然后传递给全连接层进行预测。
    • 对应于所有缩放因子的 FLOPs 和权重数量均列出。

ShuffleNet V2 的设计目标是在保持准确性的同时尽可能地减少计算量。它通过通道拆分、通道混洗和适当的重复次数来实现这一点。此外,它还优化了内存访问成本,避免了不必要的分组卷积,并尽量减少了网络碎片化。

四. ShuffleNet V2 代码实现

开发环境配置说明:本项目使用 Python 3.6.13 和 PyTorch 1.10.2 构建,适用于CPU环境。

  • model.py:定义网络模型
  • train.py:加载数据集并训练,计算 loss 和 accuracy,保存训练好的网络参数
  • predict.py:用自己的数据集进行分类测试
  • utils.py:依赖脚本
  • my_dataset.py:依赖脚本
  1. model.py
from typing import List,Callable
import torch
from torch import Tensor
import torch.nn as nn


#   定义channel_shuffle
def channel_shuffle(x: Tensor, groups: int) -> Tensor:

    #   获取输入x的[B, C, H, W]
    batch_size, num_channels, height, width = x.size()
    #   获取每个组的channel
    channel_per_group = num_channels // groups

    #   reshape
    #   [B, C, H, W] -> [B, G, C, H, W]
    x = x.view(batch_size, groups, channel_per_group, height, width)

    #   调换维度1和维度2 -> [G, B, C, H, W]
    x = torch.transpose(x, 1, 2).contiguous()

    #   flatten
    x = x.view(batch_size, -1, height, width)

    return x

class InvertedResidual(nn.Module):
    #   input_c:输入特征矩阵通道 output_c:输出特征矩阵通道 stride:DW卷积步幅
    def __init__(self, input_c: int, output_c: int, stride: int):
        super(InvertedResidual, self).__init__()

        #   判断stride是否只取1和2
        if stride not in [1, 2]:
            raise ValueError("illegal stride value")
        self.stride = stride

        #   判断output_c是否为2的整数倍(结构左右分支的通道数都是相同的)
        assert output_c % 2 == 0
        branch_features = output_c // 2

        #   当stride=1为1时,input_channel应该是branch_features的两倍
        #   python中 ”<<“ 是位运算,可理解为计算x2的快速方法
        assert (self.stride != 1) or (input_c == branch_features << 1)

        if self.stride == 2:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def depthwise_conv(input_c: int,
                       output_c: int,
                       kernel_s: int,
                       stride: int = 1,
                       padding: int = 0,
                       bias: int = False) -> nn.Conv2d:
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            #   将channel均分成两份
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out

class ShuffleNetV2(nn.Module):
    #   stage_repeats:Block重复次数
    def __init__(self,
                 stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        #   input RGB image
        input_channels = 3
        output_channels = self._stage_out_channels[0]

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        #   声明stage实现方法
        self.stage2 = nn.Sequential
        self.stage3 = nn.Sequential
        self.stage4 = nn.Sequential

        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats,
                                                  self._stage_out_channels[1:]):
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            #   使用setattr(self, name, nn.Sequential(*seq))来将创建好的nn.Sequential对象设置为当前类实例的一个属性,属性名为name
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global pool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def shufflenet_v2_x0_5(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth

    :param num_classes:
    :return:
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 48, 96, 192, 1024],
                         num_classes=num_classes)

    return model


def shufflenet_v2_x1_0(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

    :param num_classes:
    :return:
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 116, 232, 464, 1024],
                         num_classes=num_classes)

    return model


def shufflenet_v2_x1_5(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth

    :param num_classes:
    :return:
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 176, 352, 704, 1024],
                         num_classes=num_classes)

    return model


def shufflenet_v2_x2_0(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth

    :param num_classes:
    :return:
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 244, 488, 976, 2048],
                         num_classes=num_classes)

    return model
  1. train.py
import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # 如果存在预训练权重则载入
    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后的全连接层外,其他权重全部冻结
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)

        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.1)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="E:/code/PyCharm_Projects/deep_learning/data_set/flower_data/flower_photos")

    # shufflenetv2_x1.0 官方权重下载地址
    # https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
    # 不使用预训练权重 default=''
    parser.add_argument('--weights', type=str, default='./shufflenetv2_x1.pth',
                        help='initial weights path')
    # 冻结除最后全连接层的所有权重 default=True
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
  1. predict.py
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import shufflenet_v2_x1_0


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "郁金香.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = shufflenet_v2_x1_0(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-0.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
  1. utils.py
import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证各平台顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 排序,保证各平台顺序一致
        images.sort()
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    plot_image = False
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # 反Normalize操作
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # 去掉x轴的刻度
            plt.yticks([])  # 去掉y轴的刻度
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        images, labels = data

        pred = model(images.to(device))

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  # update mean losses

        data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return mean_loss.item()


@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()

    # 验证样本总个数
    total_num = len(data_loader.dataset)

    # 用于存储预测正确的样本个数
    sum_num = torch.zeros(1).to(device)

    data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    return sum_num.item() / total_num
  1. my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

五. 参考内容

  1. 李沐. (2019). 动手学深度学习. 北京: 人民邮电出版社. [ISBN: 978-7-115-51364-9]
  2. 霹雳吧啦Wz. (202X). 深度学习实战系列 [在线视频]. 哔哩哔哩. URL
  3. PyTorch. (n.d.). PyTorch官方文档和案例 [在线资源]. URL
  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值