MindSpore提供基于Pipeline的数据引擎,通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。在本教程中,我们使用Mnist数据集,自动下载完成后,使用mindspore.dataset
提供的数据变换进行预处理。
在下载Mnist数据集后,使用mindspore.dataset
提供的数据变换进行预处理。
首先下载:
from download import download url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \ "notebook/datasets/MNIST_Data.zip" path = download(url, "./", kind="zip", replace=True)
数据下载完成后,获得数据集对象。
train_dataset = MnistDataset('MNIST_Data/train') test_dataset = MnistDataset('MNIST_Data/test')
打印数据集中包含的数据列名,用于dataset的预处理。
print(train_dataset.get_col_names())
输出为:['image', 'label']
确认无误后,开始处理图片格式:
def datapipe(dataset, batch_size): image_transforms = [ vision.Rescale(1.0 / 255.0, 0), vision.Normalize(mean=(0.1307,), std=(0.3081,)), vision.HWC2CHW() ] label_transform = transforms.TypeCast(mindspore.int32) dataset = dataset.map(image_transforms, 'image') dataset = dataset.map(label_transform, 'label') dataset = dataset.batch(batch_size) return dataset
# Map vision transforms and batch dataset train_dataset = datapipe(train_dataset, 64) test_dataset = datapipe(test_dataset, 64)
这一步主要是使用map对图像数据及标签进行变换处理,将输入的图像缩放为1/255,根据均值0.1307和标准差值0.3081进行归一化处理,然后将处理好的数据集打包为大小为64的batch。
之后使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
输出准确,确认无误。
开始构建网络。
mindspore.nn
类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell
类,并重写__init__
方法和construct
方法。__init__
包含所有网络层的定义,construct
中包含数据(Tensor)的变换过程。
模型训练
在模型训练中,一个完整的训练过程(step)需要实现以下三步:
-
正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
-
反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
参数优化:将梯度更新到参数上。
-
MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:
-
定义正向计算函数。
-
使用value_and_grad通过函数变换获得梯度计算函数。
-
定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。
代码与输出如下:
模型训练完成后,需要将其参数进行保存。还要加载:
加载后的模型,可直接用于推理: