处理数据集
首先安装download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
用上述代码可以实现下载并解压所需文件到指定位置,
下载完成之后,用mindspore.dataset对提供的数据进行预处理。
网络构建
继承nn.Cell
类,并重写__init__
方法和construct
方法,可以实现自定义网络。
关于训练
三步走:
第一步、正向计算:
目的是:模型预测结果,并与正确标签求预测损失;
第二步、反向传播:
目的是:利用自动微分机制,自动求模型参数对于loss的梯度;
第三步、参数优化:
目的是:将梯度更新到参数上。
收尾
保存模型、加载模型