MNIST 数据集简介
MNIST(Modified National Institute of Standards and Technology)数据集是机器学习和深度学习中最常用的经典数据集之一。它包含 28×28 的灰度手写数字图片,涵盖了从 0 到 9 的 10 个类别。MNIST 数据集广泛用于图像分类、机器学习算法的验证和深度学习模型的初学训练。
1. 数据集特性
- 样本数量:
- 训练集:60,000 张图片
- 测试集:10,000 张图片
- 图像大小:
- 每张图片为 28 × 28 28 \times 28 28×28 像素的灰度图像(单通道)。
- 标签:
- 每张图片对应一个数字,范围为 0 到 9。
2. 加载 MNIST 数据集
2.1 使用 PyTorch 加载
PyTorch 提供了 torchvision.datasets
模块,可以非常方便地加载 MNIST 数据集。
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 定义数据转换操作
transform = transforms.Compose([
transforms.ToTensor(), # 转换为 PyTorch 张量
transforms.Normalize((0.5,), (0.5,)) # 归一化到 [-1, 1]
])
# 下载并加载训练集和测试集
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
transform=transform,
download=True
)
test_dataset = torchvision.datasets.MNIST(
root='./data',
train=False,
transform=transform,
download=True
)
# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
2.2 参数解析
root
:指定数据集的存储路径。如果路径下没有数据集,会自动下载。train
:True
:加载训练数据集。False
:加载测试数据集。
transform
:- 用于对数据进行预处理,例如转换为张量、归一化等。
download
:如果数据集未下载,设置为True
自动下载。batch_size
:每次加载的数据数量。shuffle
:True
:每次迭代时随机打乱数据。False
:按顺序加载数据。
3. MNIST 数据集的结构
加载后的数据集是一个包含图片和标签的集合,每个样本包含以下两部分:
- 输入数据(图片):
- 28 × 28 28 \times 28 28×28 的灰度值矩阵,范围为 [0, 255]。
- 使用
transforms.ToTensor()
转换后,范围归一化为 [0, 1]。
- 标签(数字):
- 每张图片的数字标签(0 到 9)。
4. 查看数据集内容
import matplotlib.pyplot as plt
# 从数据集中取出一个样本
image, label = train_dataset[0]
# 显示图片和标签
plt.imshow(image.squeeze(), cmap='gray') # 去掉通道维度并以灰度图显示
plt.title(f"Label: {label}")
plt.axis('off')
plt.show()
5. 简单模型训练示例
以下是一个用 PyTorch 训练 MNIST 数据集的简单例子:
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的全连接神经网络
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc = nn.Sequential(
nn.Flatten(), # 展平输入
nn.Linear(28 * 28, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 10) # 输出10个类别
)
def forward(self, x):
return self.fc(x)
# 创建模型
model = SimpleNN()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练循环
epochs = 5
for epoch in range(epochs):
model.train()
for images, labels in train_loader:
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")
# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Accuracy: {100 * correct / total:.2f}%")
6. 扩展 MNIST 的变体
-
- 包含 10 类服装物品的灰度图像,尺寸为 28 × 28 28 \times 28 28×28,类似于 MNIST。
- 可通过
torchvision.datasets.FashionMNIST
加载。
-
EMNIST:
- 包含扩展的手写字符数据,包括字母和数字。
-
KMNIST:
- 日文字符数据集,类似于 MNIST。
总结
MNIST 是一个经典的数据集,具有以下特点:
- 简单易用,适合初学者入门。
- 可用于验证分类模型的效果。
- 支持快速实现网络的训练和测试。
通过 MNIST 数据集,你可以熟悉深度学习的基础流程,如数据加载、模型定义、训练和评估。对于更复杂的任务,可以扩展到其他更大的数据集,如 CIFAR 或 ImageNet。