MINST数据集处理

如果你下载的MNIST数据集没有processed目录下的training.pttest.pt文件,而是有.gz格式的原始数据文件,那么你需要手动处理这些文件以生成training.pttest.pt文件,或者直接从原始数据文件中加载数据。
下面是一个示例代码,展示了如何从原始的.gz文件中读取MNIST数据集,并将其转换为PyTorch的Tensor格式:

import torch
from torchvision import datasets, transforms
import os
import gzip
import numpy as np
# 定义转换操作
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# 定义函数从原始的.gz文件中读取数据
def load_mnist_images(filename):
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
        data = data.reshape(-1, 28*28)
        return torch.tensor(data, dtype=torch.float32).view(-1, 1, 28, 28) / 255.0
def load_mnist_labels(filename):
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=8)
        return torch.tensor(data, dtype=torch.long)
# 指定MNIST数据集的本地路径
root = 'data/MNIST'
# 读取训练数据和标签
train_images = load_mnist_images(os.path.join(root, 'train-images-idx3-ubyte.gz'))
train_labels = load_mnist_labels(os.path.join(root, 'train-labels-idx1-ubyte.gz'))
# 读取测试数据和标签
test_images = load_mnist_images(os.path.join(root, 't10k-images-idx3-ubyte.gz'))
test_labels = load_mnist_labels(os.path.join(root, 't10k-labels-idx1-ubyte.gz'))
# 创建自定义数据集
class MNISTDataset(torch.utils.data.Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]
# 创建训练和测试数据集
train_dataset = MNISTDataset(train_images, train_labels)
test_dataset = MNISTDataset(test_images, test_labels)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 使用数据加载器
for images, labels in train_loader:
    # 这里可以添加你的训练或验证代码
    pass

这段代码会从原始的.gz文件中读取数据,并创建自定义的MNISTDataset类,然后使用DataLoader进行数据加载。这样,你就不需要预先处理数据为.pt文件了。

以下是对上述代码中定义的两个函数load_mnist_imagesload_mnist_labels的详细解释,说明为什么要进行这些步骤:

load_mnist_images 函数解释:

  1. 打开.gz文件:
    with gzip.open(filename, 'rb') as f:
    
    使用gzip.open以二进制读取模式打开.gz文件。.gz是一个压缩格式,用于减小文件大小,提高存储和传输效率。
  2. 读取并跳过文件头部:
    data = np.frombuffer(f.read(), np.uint8, offset=16)
    
    f.read()读取整个文件内容,np.frombuffer将字节字符串转换为NumPy数组。offset=16是为了跳过文件头部,因为MNIST图像文件的前16个字节是文件头信息,不包含图像数据。
  3. 数据形状调整:
    data = data.reshape(-1, 28*28)
    
    reshape(-1, 28*28)将一维数组转换为二维数组,其中-1表示自动计算行数,28*28表示每行有784个元素(即一个28x28像素的图像展平后的长度)。
  4. 转换为PyTorch张量并标准化:
    return torch.tensor(data, dtype=torch.float32).view(-1, 1, 28, 28) / 255.0
    
    torch.tensor将NumPy数组转换为PyTorch张量,并指定数据类型为torch.float32.view(-1, 1, 28, 28)将张量重新排列为所需的四维形状(批量大小,通道数,高度,宽度)。最后,除以255.0将像素值从0-255的范围标准化到0-1的范围。

load_mnist_labels 函数解释:

  1. 打开.gz文件:
    with gzip.open(filename, 'rb') as f:
    
    同样使用gzip.open以二进制读取模式打开.gz文件。
  2. 读取并跳过文件头部:
    data = np.frombuffer(f.read(), np.uint8, offset=8)
    
    读取整个文件内容,并将字节字符串转换为NumPy数组。offset=8用于跳过文件头部,因为MNIST标签文件的前8个字节是文件头信息。
  3. 转换为PyTorch张量:
    return torch.tensor(data, dtype=torch.long)
    
    torch.tensor将NumPy数组转换为PyTorch张量,并指定数据类型为torch.long,因为标签是整数。

为什么这样做:

  • 直接从.gz文件读取:这样做是因为你已经有.gz格式的原始数据文件,而不是已经处理好的.pt文件。
  • 跳过文件头:MNIST数据集的文件格式包含一个头部,它描述了数据集的一些元信息,如魔数、图片数量等,这些信息在读取图像或标签数据时不需要。
  • 数据形状和类型转换:将数据转换为正确的形状和类型是进行深度学习模型训练的必要步骤。PyTorch要求输入数据的维度和类型与模型期望的输入相匹配。
  • 标准化:标准化输入数据是深度学习中的一个常见预处理步骤,它有助于模型更好地学习。通过将像素值标准化到0-1范围,可以加快收敛速度并提高模型性能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值