什么是 MNIST数据集

MNIST 数据集简介

MNIST(Modified National Institute of Standards and Technology)数据集是机器学习和深度学习中最常用的经典数据集之一。它包含 28×28 的灰度手写数字图片,涵盖了从 0 到 9 的 10 个类别。MNIST 数据集广泛用于图像分类、机器学习算法的验证和深度学习模型的初学训练。


1. 数据集特性

  • 样本数量
    • 训练集:60,000 张图片
    • 测试集:10,000 张图片
  • 图像大小
    • 每张图片为 28 × 28 28 \times 28 28×28 像素的灰度图像(单通道)。
  • 标签
    • 每张图片对应一个数字,范围为 0 到 9。

2. 加载 MNIST 数据集

2.1 使用 PyTorch 加载

PyTorch 提供了 torchvision.datasets 模块,可以非常方便地加载 MNIST 数据集。

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义数据转换操作
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为 PyTorch 张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])

# 下载并加载训练集和测试集
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True
)

# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

2.2 参数解析
  • root:指定数据集的存储路径。如果路径下没有数据集,会自动下载。
  • train
    • True:加载训练数据集。
    • False:加载测试数据集。
  • transform
    • 用于对数据进行预处理,例如转换为张量、归一化等。
  • download:如果数据集未下载,设置为 True 自动下载。
  • batch_size:每次加载的数据数量。
  • shuffle
    • True:每次迭代时随机打乱数据。
    • False:按顺序加载数据。

3. MNIST 数据集的结构

加载后的数据集是一个包含图片和标签的集合,每个样本包含以下两部分:

  1. 输入数据(图片)
    • 28 × 28 28 \times 28 28×28 的灰度值矩阵,范围为 [0, 255]。
    • 使用 transforms.ToTensor() 转换后,范围归一化为 [0, 1]。
  2. 标签(数字)
    • 每张图片的数字标签(0 到 9)。

4. 查看数据集内容

import matplotlib.pyplot as plt

# 从数据集中取出一个样本
image, label = train_dataset[0]

# 显示图片和标签
plt.imshow(image.squeeze(), cmap='gray')  # 去掉通道维度并以灰度图显示
plt.title(f"Label: {label}")
plt.axis('off')
plt.show()

5. 简单模型训练示例

以下是一个用 PyTorch 训练 MNIST 数据集的简单例子:

import torch.nn as nn
import torch.optim as optim

# 定义一个简单的全连接神经网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),  # 展平输入
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)  # 输出10个类别
        )

    def forward(self, x):
        return self.fc(x)

# 创建模型
model = SimpleNN()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练循环
epochs = 5
for epoch in range(epochs):
    model.train()
    for images, labels in train_loader:
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")

6. 扩展 MNIST 的变体

  • Fashion-MNIST

    • 包含 10 类服装物品的灰度图像,尺寸为 28 × 28 28 \times 28 28×28,类似于 MNIST。
    • 可通过 torchvision.datasets.FashionMNIST 加载。
  • EMNIST

    • 包含扩展的手写字符数据,包括字母和数字。
  • KMNIST

    • 日文字符数据集,类似于 MNIST。

总结

MNIST 是一个经典的数据集,具有以下特点:

  • 简单易用,适合初学者入门。
  • 可用于验证分类模型的效果。
  • 支持快速实现网络的训练和测试。

通过 MNIST 数据集,你可以熟悉深度学习的基础流程,如数据加载、模型定义、训练和评估。对于更复杂的任务,可以扩展到其他更大的数据集,如 CIFAR 或 ImageNet。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值