这次学习了如何使用 MindSpore 来实现手写数字识别任务。以下是主要步骤:
-
下载并处理数据集:
- 可以从华为云下载数据集,速度更快。
- MNIST 数据集包含 10 类 28x28 的灰度图像,训练集有 60000 张图片,测试集有 10000 张图片。
-
创建模型:
- 使用自定义网络模型,共5层(2个Relu层、3个Dense层)。
-
定义损失函数和优化器:
- 使用交叉熵损失函数
CrossEntropyLoss
。 - 使用SGD优化器。
- 使用交叉熵损失函数
-
训练及保存模型:
- 使用
value_and_grad
接口生成求导函数,用于计算forward函数的正向计算结果和梯度。 - 使用
save_checkpoint
接口保存网络模型和参数。
- 使用
-
加载及使用模型:
- 使用
load_checkpoint
接口加载参数。 - 使用
load_param_into_net
接口加载参数到模型。 - 使用
set_train
接口设置模型为预测推理模式。
- 使用
教程来自:
https://gitee.com/mindspore/docs/blob/r2.3/tutorials/source_zh_cn/beginner/quick_start.ipynb