Pytorh学习——DataSet和DataLoader

目录

Pytorch的数据集

DataSet

DataLoader

创建自定义数据集

参考文档


Pytorch的数据集

Pytorch深度学习库以一种可读性强、模块化程度高的方式来构建深度学习网络。在构建深度学习网络时,数据的加载和预处理是一项重要而繁琐的工作。如果在构建网络中, 我们需要为加载样本数据、样本数据预处理编写大量的处理代码,会导致代码变得混乱、网络构建过程不清晰,最终难以维护。

基于以上考虑,Pytorch将数据集和数据集的加载定义为两个单独对象,使数据集代码和模型训练代码相分离,以获得更好的可读性和模块化。

Pytorch提供了两个DataSet和DataLoader两个类。

DataSet

DataSet是数据集对象类, Pytorch提供了大量的默认数据集, 包括Fashion-MINST、CIFAR-10、CIFAR-100、CelebA等数据集。如果用户想要加载自定义的数据只需要继承DataSet类。

Pytorch支持两种类型的DataSet:

  • Map类型DataSet
  • Iterable类型DataSet

Map类型DataSet

Map类型DataSet实现__getitem__()和 __len__(),表示从索引/键到数据样本的映射。数据集在使用 访问时,可以通过索引直接获取相关样本数据。例如,dataset[idx]表示使用idx从磁盘上的文件夹中读取第i个图像及其相应的标签。

Iterable类型DataSet

IterableDataset 实现了__iter__()函数,可对数据样本进行迭代访问。这种类型的数据集特别适用于随机读取代价高昂以及批量大小取决于获取的数据等场景。

例如,在从数据库、远程服务器甚至实时生成的日志中读取的数据流场景中,可以使用iter(dataset)来访问数据

代码示例

import torch
from torchvision import datasets

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

上面示例中,会下载Pytorch提供的数据集FashionMNIST。参数说明:

  • root:  数据集文件路径,DataSet会到该目录下查找相关数据集文件。
  • train: 是否是训练数据。 True:表示训练数据; False: 测试数据。如果为True,则下载训练数据集;如果为False,则下载测试数据集。
  • download: 当本地数据集不存在时, 是否远程下载到本地。
  • transform: 数据转换操作, ToTensor()将数据转换为张量。

DataLoader

Dataloader 是一个迭代器,最基本的使用就是传入一个 Dataset 对象,它就会根据参数 batch_size 的值生成一个 batch 的数据.

DataLoader函数

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

常用参数说明:

  • dataset: 需要加载的数据集
  • batch_size: 每个迭代返回的样本数
  • shuffle: 如果为True,则每次epoch时对数据进行shuffle操作。

示例代码

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

train_features, train_labels = next(iter(train_dataloader))

上面代码中,DataLoader封装training_data、test_data为迭代器。 通过train_dataloader迭代获取样本数据,每个迭代获取64个(batch_size)个样本。

加载Fashion-MNIST数据集

下例为Pytorch关于DataSet和DataLoader的官方示例

该示例是从 TorchVision加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 的文章图像数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含一个 28×28 灰度图像和来自 10 个类别之一的相关标签。

我们使用以下参数加载FashionMNIST 数据集

  • root :存储训练/测试数据的路径,
  • train :指定训练或测试数据集,
  • download=True:如果数据不可用,则下载数据root
  • transform&target_transform:指定特征和标签转换

示例代码

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()


train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

将该数据集加载到 中,DataLoader并且可以根据需要遍历数据集。每次迭代都会返回一批(64个样本)train_featurestrain_labelsbatch_size=64分别包含特征和标签)。因为指定了shuffle=True,在遍历所有批次后,数据会被打乱。

运行结果

../../_images/sphx_glr_data_tutorial_001.png

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 0

参考文档

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: DatasetDataLoaderPyTorch 中用于加载和处理数据的两个主要组件。Dataset 用于从数据源中提取和加载数据,DataLoader 则用于将数据转换为适合机器学习模型训练的格式。 ### 回答2: 在PyTorch中,DatasetDataLoader是用于处理和加载数据的两个重要类。 Dataset是一个抽象类,用于表示数据集对象。我们可以自定义Dataset子类来处理我们自己的数据集。通过继承Dataset类,我们需要实现两个主要方法: - __len__()方法:返回数据集的大小(样本数量) - __getitem__(idx)方法:返回索引为idx的样本数据 使用Dataset类的好处是可以统一处理训练集、验证集和测试集等不同的数据集,将数据进行一致的格式化和预处理。 DataLoader是一个实用工具,用于将Dataset对象加载成批量数据。数据加载器可以根据指定的批大小、是否混洗样本和多线程加载等选项来提供高效的数据加载方式。DataLoader是一个可迭代对象,每次迭代返回一个批次的数据。我们可以通过循环遍历DataLoader对象来获取数据。 使用DataLoader可以实现以下功能: - 数据批处理:将数据集划分为批次,并且可以指定每个批次的大小。 - 数据混洗:可以通过设置shuffle选项来随机打乱数据集,以便更好地训练模型。 - 并行加载:可以通过设置num_workers选项来指定使用多少个子进程来加载数据,加速数据加载过程。 综上所述,DatasetDataLoaderPyTorch中用于处理和加载数据的两个重要类。Dataset用于表示数据集对象,我们可以自定义Dataset子类来处理我们自己的数据集。而DataLoader是一个实用工具,用于将Dataset对象加载成批量数据,提供高效的数据加载方式,支持数据批处理、数据混洗和并行加载等功能。 ### 回答3: 在pytorch中,Dataset是一个用来表示数据的抽象类,它封装了数据集的访问方式和数据的获取方法。Dataset类提供了读取、处理和转换数据的功能,可以灵活地处理各种类型的数据集,包括图像、语音、文本等。用户可以继承Dataset类并实现自己的数据集类,根据实际需求定制数据集。 Dataloader是一个用来加载数据的迭代器,它通过Dataset对象来获取数据,并按照指定的batch size进行分批处理。Dataloader可以实现多线程并行加载数据,提高数据读取效率。在训练模型时,通常将Dataset对象传入Dataloader进行数据加载,并通过循环遍历Dataloader来获取每个batch的数据进行训练。 DatasetDataloader通常配合使用,Dataset用于数据的读取和预处理,Dataloader用于并行加载和分批处理数据。使用DatasetDataloader的好处是可以轻松地处理大规模数据集,实现高效的数据加载和预处理。此外,DatasetDataloader还提供了数据打乱、重复采样、数据划分等功能,可以灵活地控制数据的访问和使用。 总之,DatasetDataloaderpytorch中重要的数据处理模块,它们提供了方便的接口和功能,用于加载、处理和管理数据集,为模型训练和评估提供了便利。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值