如果你下载的MNIST数据集没有processed
目录下的training.pt
和test.pt
文件,而是有.gz
格式的原始数据文件,那么你需要手动处理这些文件以生成training.pt
和test.pt
文件,或者直接从原始数据文件中加载数据。
下面是一个示例代码,展示了如何从原始的.gz
文件中读取MNIST数据集,并将其转换为PyTorch的Tensor
格式:
import torch
from torchvision import datasets, transforms
import os
import gzip
import numpy as np
# 定义转换操作
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 定义函数从原始的.gz文件中读取数据
def load_mnist_images(filename):
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
data = data.reshape(-1, 28*28)
return torch.tensor(data, dtype=torch.float32).view(-1, 1, 28, 28) / 255.0
def load_mnist_labels(filename):
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=8)
return torch.tensor(data, dtype=torch.long)
# 指定MNIST数据集的本地路径
root = 'data/MNIST'
# 读取训练数据和标签
train_images = load_mnist_images(os.path.join(root, 'train-images-idx3-ubyte.gz'))
train_labels = load_mnist_labels(os.path.join(root, 'train-labels-idx1-ubyte.gz'))
# 读取测试数据和标签
test_images = load_mnist_images(os.path.join(root, 't10k-images-idx3-ubyte.gz'))
test_labels = load_mnist_labels(os.path.join(root, 't10k-labels-idx1-ubyte.gz'))
# 创建自定义数据集
class MNISTDataset(torch.utils.data.Dataset):
def __init__(self, images, labels):
self.images = images
self.labels = labels
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.images[idx], self.labels[idx]
# 创建训练和测试数据集
train_dataset = MNISTDataset(train_images, train_labels)
test_dataset = MNISTDataset(test_images, test_labels)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 使用数据加载器
for images, labels in train_loader:
# 这里可以添加你的训练或验证代码
pass
这段代码会从原始的.gz
文件中读取数据,并创建自定义的MNISTDataset
类,然后使用DataLoader
进行数据加载。这样,你就不需要预先处理数据为.pt
文件了。
以下是对上述代码中定义的两个函数load_mnist_images
和load_mnist_labels
的详细解释,说明为什么要进行这些步骤:
load_mnist_images
函数解释:
- 打开.gz文件:
使用with gzip.open(filename, 'rb') as f:
gzip.open
以二进制读取模式打开.gz
文件。.gz
是一个压缩格式,用于减小文件大小,提高存储和传输效率。 - 读取并跳过文件头部:
data = np.frombuffer(f.read(), np.uint8, offset=16)
f.read()
读取整个文件内容,np.frombuffer
将字节字符串转换为NumPy数组。offset=16
是为了跳过文件头部,因为MNIST图像文件的前16个字节是文件头信息,不包含图像数据。 - 数据形状调整:
data = data.reshape(-1, 28*28)
reshape(-1, 28*28)
将一维数组转换为二维数组,其中-1
表示自动计算行数,28*28
表示每行有784个元素(即一个28x28像素的图像展平后的长度)。 - 转换为PyTorch张量并标准化:
return torch.tensor(data, dtype=torch.float32).view(-1, 1, 28, 28) / 255.0
torch.tensor
将NumPy数组转换为PyTorch张量,并指定数据类型为torch.float32
。.view(-1, 1, 28, 28)
将张量重新排列为所需的四维形状(批量大小,通道数,高度,宽度)。最后,除以255.0将像素值从0-255的范围标准化到0-1的范围。
load_mnist_labels
函数解释:
- 打开.gz文件:
同样使用with gzip.open(filename, 'rb') as f:
gzip.open
以二进制读取模式打开.gz
文件。 - 读取并跳过文件头部:
读取整个文件内容,并将字节字符串转换为NumPy数组。data = np.frombuffer(f.read(), np.uint8, offset=8)
offset=8
用于跳过文件头部,因为MNIST标签文件的前8个字节是文件头信息。 - 转换为PyTorch张量:
return torch.tensor(data, dtype=torch.long)
torch.tensor
将NumPy数组转换为PyTorch张量,并指定数据类型为torch.long
,因为标签是整数。
为什么这样做:
- 直接从.gz文件读取:这样做是因为你已经有.gz格式的原始数据文件,而不是已经处理好的.pt文件。
- 跳过文件头:MNIST数据集的文件格式包含一个头部,它描述了数据集的一些元信息,如魔数、图片数量等,这些信息在读取图像或标签数据时不需要。
- 数据形状和类型转换:将数据转换为正确的形状和类型是进行深度学习模型训练的必要步骤。PyTorch要求输入数据的维度和类型与模型期望的输入相匹配。
- 标准化:标准化输入数据是深度学习中的一个常见预处理步骤,它有助于模型更好地学习。通过将像素值标准化到0-1范围,可以加快收敛速度并提高模型性能。