Pytorch手写数字数据集MNIST下载与查看
MNIST数据集介绍
MNIST数据集是一个广泛使用的手写数字数据集,由美国国家标准与技术研究所(NIST)发起并整理。这个数据集包含了来自250个不同的人手写数字的图片,其中一半是高中生,另一半来自人口普查局的工作人员。主要目的是通过算法实现对手写数字的识别。
MNIST数据集一共包含了70000张图像,其中60000张用于训练,10000张用于测试。每张图像都是28×28像素的灰度图像,代表一个手写数字,范围从0到9。每张图像都附带一个标签,表示该图像上写的是哪个数字。
数据的下载与加载
运行代码,会在你指定的路径下下载数据集并且解压:
import torch
from torchvision import datasets, transforms
# 定义图像转换操作:将图像转换为张量,并进行归一化
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# 下载并加载训练数据集
trainset = datasets.MNIST('./MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# 下载并加载测试数据集
testset = datasets.MNIST('./MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
查看数据集
下载解压的数据集我们不能直观的查看,如果有需要,可以使用代码提取出图片。如下图所示:
import os
import torch
from torchvision import datasets, transforms
from PIL import Image
# 加载MNIST数据集,这次不应用任何转换,因为默认返回的就是PIL Image
mnist_train = datasets.MNIST(root='./MNIST_data/', train=True, download=True)
mnist_test = datasets.MNIST(root='./MNIST_data/', train=False, download=True)
# 定义保存图片和标签的目录
image_dir = 'mnist_images'
label_dir = 'mnist_labels'
# 确保保存目录存在
os.makedirs(image_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)
# 遍历训练数据集并保存图片和标签
for idx, (image, label) in enumerate(mnist_train):
# 保存图片
image_path = os.path.join(image_dir, f'{idx:05d}.png') # 使用5位数序号命名图片文件
image.save(image_path)
# 保存标签
label_path = os.path.join(label_dir, f'{idx:05d}.txt')
with open(label_path, 'w') as f:
f.write(f'{idx:05d} {label}\n') # 文件内容是图片序号和对应标签
# 遍历测试数据集并保存图片和标签
for idx, (image, label) in enumerate(mnist_test, start=len(mnist_train)):
# 保存图片
image_path = os.path.join(image_dir, f'{idx:05d}.png') # 使用5位数序号命名图片文件
image.save(image_path)
# 保存标签
label_path = os.path.join(label_dir, f'{idx:05d}.txt')
with open(label_path, 'w') as f:
f.write(f'{idx:05d} {label}\n') # 文件内容是图片序号和对应标签
print("图片和标签保存完成!")