第二节快速入门举了一个例子,可以更加清晰地看到数据是如何被处理、模型是如何被训练的。
1. 环境准备
2. 处理数据集
在这里使用Mnist数据集,自动下载完成后,使用mindspore.dataset提供的数据变换进行预处理。
本章节中的示例代码依赖download,可使用命令pip install download安装。如本文档以Notebook运行时,完成安装后需要重启kernel才能执行后续代码。
数据下载完成后,获得数据集对象。
打印数据集中包含的数据列名,用于dataset的预处理。
MindSpore的dataset使用数据处理流水线(Data Processing Pipeline),需指定map、batch、shuffle等操作。这里我们使用map对图像数据及标签进行变换处理,然后将处理好的数据集打包为大小为64的batch。
- vision.Rescale第一个参数代表rescale(缩放因子),第二个参数代表shift(平移因子)。基于给定的缩放和平移因子调整图像的像素大小。输出图像的像素大小为:output = image * rescale + shift。
- vision.Normalize根据均值和标准差对输入图像进行归一化。output[channel] = (input[channel] - mean[channel]) / std[channel],其中 channel 代表通道索引,channel >= 1。
- 第一个参数mean (sequence) - 图像每个通道的均值组成的列表或元组。平均值必须在 [0.0, 255.0] 范围内。
- 第二个参数std (sequence) - 图像每个通道的标准差组成的列表或元组。标准差值必须在 (0.0, 255.0] 范围内。
- vision.HWC2CHW的作用: shape (H, W, C) to shape (C, H, W).
- transforms.TypeCast将输入的Tensor转换为指定的数据类型。
可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatyp
Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32
Shape of label: (64,) Int32
Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32
Shape of label: (64,) Int32
3. 网络构建
mindspore.nn类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell类,并重写__init__方法和construct方法。__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。
Network<
(flatten): Flatten<>
(dense_relu_sequential): SequentialCell<
(0): Dense<input_channels=784, output_channels=512, has_bias=True>
(1): ReLU<>
(2): Dense<input_channels=512, output_channels=512, has_bias=True>
(3): ReLU<>
(4): Dense<input_channels=512, output_channels=10, has_bias=True>
4. 模型训练
在模型训练中,一个完整的训练过程(step)需要实现以下三步:
正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
参数优化:将梯度更新到参数上。
MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:
定义正向计算函数。
使用value_and_grad通过函数变换获得梯度计算函数。
定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
除训练外,我们定义测试函数,用来评估模型的性能。
训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。
5. 保存模型
模型训练完成后,需要将其参数进行保存。
6. 加载模型
加载保存的权重分为两步:
- 重新实例化模型对象,构造模型。
- 加载模型参数,并将其加载至模型上。
param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。
加载后的模型可以直接用于预测推理。