一、操作流程总结
-
数据集准备与加载
- 下载 CIFAR-10 数据集,并解压。
- 使用
mindspore.dataset.Cifar10Dataset
接口加载数据集,并进行图像增强操作,包括随机裁剪、水平翻转、调整大小、缩放、归一化等。 - 将数据集分为训练集和测试集,并获取数据集的大小。
-
构建网络
- 残差网络结构,包括两种类型:Building Block(适用于较浅的 ResNet 网络)和 Bottleneck(适用于层数较深的 ResNet 网络,如 ResNet50)。
- 分别定义了实现 Building Block 和 Bottleneck 结构的类
ResidualBlockBase
和ResidualBlock
。 - 通过
make_layer
函数实现残差块的构建。 - 定义
ResNet
类构建 ResNet50 网络,包括不同的卷积层、池化层、平均池化层和全连接层。
-
模型训练与评估
- 调用
resnet50
构造 ResNet50 模型,并设置pretrained=True
加载预训练模型。 - 由于 CIFAR10 数据集分类数与预训练模型不同,重置全连接层输出大小。
- 设置学习率、优化器和损失函数。
- 定义训练和评估的函数,进行多个 epoch 的训练,并保存评估精度最高的模型。
- 定义可视化模型预测的函数,对测试数据集的预测结果进行可视化展示。
- 调用
二、函数、参数与库总结
-
create_dataset_cifar10
函数- 参数:
dataset_dir
:数据集根目录。usage
:指定数据集的使用方式,如训练或测试。resize
:图像调整大小的尺寸。batch_size
:批量大小。workers
:并行线程个数。
- 功能:根据指定参数加载和处理 CIFAR-10 数据集,并进行数据映射和批量操作。
- 参数:
-
ResidualBlockBase
类- 参数:
in_channel
:输入通道数。out_channel
:输出通道数。stride
:卷积步幅。norm
:可选的归一化层。down_sample
:下采样操作。
- 功能:实现 Building Block 结构的残差块。
- 参数:
-
ResidualBlock
类- 参数:
in_channel
:输入通道数。out_channel
:输出通道数。stride
:卷积步幅。down_sample
:下采样操作。
- 功能:实现 Bottleneck 结构的残差块。
- 参数:
-
make_layer
函数- 参数:
last_out_channel
:上一个残差网络输出的通道数。block
:残差网络的类别,如ResidualBlockBase
或ResidualBlock
。channel
:残差网络输入的通道数。block_nums
:残差网络块堆叠的个数。stride
:卷积移动的步幅。
- 功能:构建残差网络层。
- 参数:
-
ResNet
类- 参数:
block
:残差块的类型。layer_nums
:每个残差网络结构块堆叠的个数列表。num_classes
:分类的类别数。input_channel
:输入通道数。
- 功能:构建 ResNet 模型。
- 参数:
-
mindspore
库mindspore.nn
:包含神经网络层、损失函数等模块。mindspore.dataset
:用于数据加载、处理和操作。mindspore.ops
:提供各种操作符和函数。
三、个人思考与阐述
- 残差网络结构的创新性在于通过引入 shortcuts 分支,有效地解决了传统卷积神经网络在加深时出现的退化问题,使得构建极深的网络成为可能,从而提高模型的表达能力和训练精度。
- 在数据加载和预处理部分,合理的图像增强操作(如随机裁剪、翻转等)可以增加数据的多样性,有助于提高模型的泛化能力。
- 对于不同深度的 ResNet 网络选择不同的残差结构(Building Block 和 Bottleneck),体现了在模型设计中根据计算资源和性能需求进行权衡的思想。
- 模型训练中的超参数设置(如学习率、优化器的参数、epoch 数量等)对模型的训练效果有重要影响,需要根据具体问题和数据集进行调整和优化。
- 可视化模型预测结果有助于直观地了解模型的性能和错误类型,为进一步改进模型提供了直观的依据。