Pytorch加载MNIST数据集
MNIST数据集是入门深度学习的一个重要的数据集。由National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)发布,发布时间为1998年。
数据集详细介绍:
MNIST数据集有10个类别,分别代表0-9之间的数字。共有60000张图像作为训练集,10000张图像作为测试集。官方地址:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burgeshttp://yann.lecun.com/exdb/mnist/下面我们用Pytorch来加载数据集
准备工作
推荐先下载好MNIST数据集。
我是用别人提供的博客上的数据集下载的。
解压后的数据集如下
dataset_uncompressed/
├── t10k-images-idx3-ubyte #测试集图像数据
├── t10k-labels-idx1-ubyte #测试集标签数据
├── train-images-idx3-ubyte #训练集图像数据
└── train-labels-idx1-ubyte #训练集标签数据
Pytorch加载该数据集
第一步导入相关的Python包
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
主要是三个包
torch主要使用dataloader类。
torchvision内置了很多图像处理的函数,是计算机视觉的最重要的一个库。主要用来图像处理和读取数据集
d2l是沐神写的Pytorch的部分代码,主要用于方便展示数据集。
首先利用d2l
d2l.use_svg_display()来使用svg格式展示数据集
读取的代码如下
打印一下数据集的len
和之前说的一样训练集共60000张图片,测试集共10000张图片
我们看一下其中的一个图像,是28*28像素大小
展示一下数据集
利用pytorch的DataLoader可以非常方便的加载数据集。我们定义每次加载32个图片。
下面再使用next和iter把它变成一个迭代器
利用matplotlib展示一下这些图片