1.mnist手写数字数据集
mnist手写数字数据集是经常被用来初学者使用的图片数据集,其图片大小为28*28,一般的图片都是(H,W,3),由于mnist数据集图片都是数字,所以可以不需要RGB,其数据集图片如下:
对于该数据集标签为数字所代表的值,如上图片为5,其对应的数字也是5,所以这是一个十分类的任务。
整个mnist有60000张训练集,10000张测试集,相较于类别数来说样本量相当充足。
2.模型构建与训练
第一步,在头部引入必要所需的环境包,代码如下:
import paddle
from paddle.vision.transforms import Normalize
第二步,获取mnist数据集,paddle直接就提供了,不过由于我们需要模型适配对应的格式,所以对mnist的数据集进行format的处理,代码如下:
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform&#