上一篇,我们主要介绍了 CNN 的基本概念和模型结构。本文将带领大家使用 PyTorch 一步步搭建 CNN 模型,进行数字图片识别。本案例中,我们选用的是 MNIST 数据集。
总的来说,我们构建分类器将按照以下步骤来做:
- 使用 torchvision 加载 MNIST 数据集;
- 定义一个卷积神经网络 CNN;
- 定义损失函数;
- 使用训练样本,训练网络;
- 在测试样本上进行测试。
MNIST 简介
MNIST 是深度学习领域中经典的手写图片数据集,这些图片采集自不同人手写的从 0 到 9 的数字,由 6 万张训练图片和 1 万张测试图片构成,每张图片都是 28*28 大小(单通道)。示例图片如下图所示:
MNIST 数据集由以下四个部分组成:
- 训练图片:
train-images-idx3-ubyte.gz
- 训练图片标签:
train-labels-idx1-ubyte.gz
- 测试图片:
t10k-images-idx3-ubyte.gz
- 测试图片标签:
t10k-labels-idx1-ubyte.gz
MNIST 数据集采用 ubyte 格式存储&