Fashion-MNIST 数据集
Fashion-MNIST 是一个替代经典 MNIST 的数据集,旨在更加贴近实际应用场景。它由 Zalando 提供,包含了 10 类服装物品的灰度图像,与 MNIST 数据集的结构相似。Fashion-MNIST 常用于图像分类和深度学习模型的测试和基准评估。
1. 数据集特性
- 样本数量:
- 训练集:60,000 张图片
- 测试集:10,000 张图片
- 图像大小:
- 每张图片为 28 × 28 28 \times 28 28×28 像素的灰度图像(单通道)。
- 标签类别:
- 共有 10 个类别,标签值范围为
0-9
,对应的类别名称如下:标签 类别名称 0 T-shirt/top 1 Trouser 2 Pullover 3 Dress 4 Coat 5 Sandal 6 Shirt 7 Sneaker 8 Bag 9 Ankle boot
- 共有 10 个类别,标签值范围为
2. 加载 Fashion-MNIST 数据集
使用 PyTorch 加载
PyTorch 提供了 torchvision.datasets.FashionMNIST
,可以方便地加载 Fashion-MNIST 数据集。
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 数据转换操作
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化到 [-1, 1]
])
# 下载并加载训练集和测试集
train_dataset = torchvision.datasets.FashionMNIST(
root='./data',
train=True,
transform=transform,
download=True
)
test_dataset = torchvision.datasets.FashionMNIST(
root='./data',
train=False,
transform=transform,
download=True
)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
3. 数据集内容可视化
可以用 Matplotlib 可视化数据集中的图片和标签:
import matplotlib.pyplot as plt
# 获取一个样本
image, label = train_dataset[0]
# 类别名称
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 显示图片
plt.imshow(image.squeeze(), cmap='gray') # 去掉通道维度并显示为灰度图
plt.title(f"Label: {classes[label]}")
plt.axis('off')
plt.show()
4. 简单模型训练示例
以下是一个简单的用 PyTorch 训练 Fashion-MNIST 的分类模型示例:
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的卷积神经网络
class FashionCNN(nn.Module):
def __init__(self):
super(FashionCNN, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc_layers = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
def forward(self, x):
x = self.conv_layers(x)
x = self.fc_layers(x)
return x
# 初始化模型
model = FashionCNN()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
epochs = 5
for epoch in range(epochs):
model.train()
for images, labels in train_loader:
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")
# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Accuracy: {100 * correct / total:.2f}%")
5. 扩展与应用
数据增强
可以使用 torchvision.transforms
增强数据,提高模型的泛化能力。
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(10), # 随机旋转
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
模型改进
- 使用更深的卷积神经网络(如 ResNet)。
- 引入 Dropout 和 Batch Normalization 提高模型性能。
学习迁移
- 使用预训练模型(如 VGG、ResNet),将输入图像扩展到 224 × 224 224 \times 224 224×224。
6. 与 MNIST 的区别
特性 | MNIST | Fashion-MNIST |
---|---|---|
数据类型 | 手写数字 | 服装图片 |
类别数 | 10 | 10 |
图像复杂度 | 较低 | 较高 |
应用场景 | 数字分类问题 | 真实世界的多分类问题 |
模型泛化能力测试 | 较低 | 较高 |
总结
Fashion-MNIST 是一个更加贴近实际问题的分类数据集,它的特点包括:
- 提供了较高的图像复杂度,适合测试模型的泛化能力。
- 数据结构与 MNIST 相似,适合入门学习和快速实验。
使用 Fashion-MNIST,可以更好地理解图像分类任务和深度学习模型的基本工作原理。对于更高性能的应用,可以尝试引入更复杂的模型或预训练网络。