pytorch之torch.utils.data学习

1、概述

用于处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望数据集代码与模型训练代码分离,以获得更好的可读性和模块化性。 PyTorch 提供了两个数据原语:torch.utils.data.DataLoader和torch.utils.data.Dataset 允许您使用预加载的数据集以及您自己的数据。 Dataset存储样本及其相应的标签,DataLoader围绕 Dataset进行迭代以方便访问样本。

Image Datasets

Torchvision 在模块中提供了许多内置数据集torchvision.datasets ,以及用于构建您自己的数据集的实用程序类。
所有数据集都是torch.utils.data.Dataset ie的子类,它们都__getitem__实现__len__了方法。因此,它们都可以传递给torch.utils.data.DataLoader 可以使用工作人员并行加载多个样本的程序torch.multiprocessing。例如:

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

PyTorch 数据加载利用的核心是torch.utils.data.DataLoader类 。它表示在数据集上 Python 可迭代,支持

map-style and iterable-style datasets(地图样式和可迭代样式数据集),
customizing data loading order(自定义数据加载顺序),
automatic batching(自动批处理),
single- and multi-process data loading(单进程和多进程数据加载),
automatic memory pinning(自动内存固定)。
这些选项由 a 的构造函数参数配置 DataLoader,其签名为:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下部分详细描述了这些选项的效果和用法。

2、Dataset Types(数据类型)

DataLoader 构造函数最重要的参数是dataset,它指示要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:

map-style datasets,(地图式数据集),
iterable-style datasets.(可迭代式数据集。)

map-style datasets 地图式数据集

映射式数据集是一种实现__getitem__()和 len()协议的数据集,并表示从(可能是非整数)索引/键到数据样本的映射。
例如,这样的数据集,当使用 访问时dataset[idx],可以从磁盘上的文件夹中读取第 idx-th个图像及其相应的标签。
请参阅Dataset了解更多详情。

iterable-style datasets可迭代式数据集

CLASStorch.utils.data.IterableDataset(*args, **kwds)

可迭代样式数据集是实现协议__iter__()的IterableDataset 子类的实例,并表示数据样本上的一个迭代迭代。这种类型的数据集特别适合随机读取成本昂贵甚至不可能的情况,以及批量大小取决于获取的数据的情况。

例如,这样的数据集在调用时iter(dataset)可以返回从数据库、远程服务器甚至实时生成的日志读取的数据流。
所有表示数据样本可迭代对象的数据集都应该继承它。当数据来自流时,这种形式的数据集特别有用
请参阅IterableDataset了解更多详情。

Example 1: splitting workload across all workers in iter():

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])

# Mult-process loading with two worker processes
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

Example 2: splitting workload across all workers using worker_init_fn:

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end
    def __iter__(self):
        return iter(range(self.start, self.end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# Directly doing multi-process loading yields duplicate data
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))

# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

# Mult-process loading with the custom `worker_init_fn`
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))

# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))

笔记

当使用多进程数据加载的IterableDataset时。在每个工作进程上复制相同的数据集对象,因此必须对副本进行不同的配置,以避免重复数据。请参阅IterableDataset文档了解如何实现这一点。

3、数据加载顺序和Sampler

对于可迭代样式的数据集,数据加载顺序完全由用户定义的可迭代控制。这允许更容易地实现块读取和动态批量大小(例如,通过每次生成批量样本)。

本节的其余部分涉及地图样式数据集的情况 。torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的序列。它们表示数据集索引上的可迭代对象。例如,在随机梯度下降 (SGD) 的常见情况下,a Sampler可以随机排列一系列索引并一次生成每个索引,或者为小批量 SGD 生成少量索引。

A sequential or shuffled sampler将根据DataLoader 的shuffle参数自动构。或者,用户可以使用sampler参数来指定一个自定义Sampler对象,该对象每次都会生成下一个要获取的索引/键。

自定义Sampler一次生成批次索引列表,可以作为batch_sampler参数传递。自动批处理也可以通过batch_size和drop_last 参数启用。有关这方面的更多详细信息,请参阅 下一节。

注意

sampler和batch_sampler都不兼容可迭代风格的数据集,因为这样的数据集没有概念。

4、加载批量和非批量数据

DataLoader支持自动将单独获取的数据样本通过参数 batch_size、drop_last、batch_sampler和 collate_fn(具有默认功能)整理为批次

自动批处理(默认)

这是最常见的情况,对应于获取小批量数据并将它们整理成批量样本,即包含一个维度为批量维度(通常是第一个维度)的张量。

当batch_size(默认1)不是None时,数据加载器将生成批量样本而不是单个样本。batch_size和 drop_last参数用于指定数据加载器如何获取批量的数据集键。对于地图样式数据集,用户也可以指定batch_sampler,这一次会生成一个键列表。

在使用来自sampler的索引获取样本列表之后,使用作为collate_fn参数传递的函数将样本列表整理成批。
在这种情况下,从地图样式的数据集加载大致相当于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

从可迭代风格的数据集加载大致相当于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

禁用自动批处理

在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者简单地加载单个样本。例如,直接加载批处理数据(例如,从数据库中批量读取或读取连续的内存块)可能更便宜,或者批处理大小依赖于数据,或者程序被设计为处理单个样本。在这些场景下,最好不要使用自动批处理(其中collate_fn用于整理样本),而是让数据加载器直接返回数据集object的每个成员。

当batch_size和batch_sampler都为None时(batch_sampler的默认值已经为None),自动批处理被禁用。从数据集中获得的每个样本都使用作为collate_fn参数传递的函数进行处理。

当自动批处理被禁用时,默认的collate_fn只是将NumPy数组转换为PyTorch张量,并保持其他所有内容不变。
在这种情况下,从地图样式的数据集加载大致相当于:

for index in sampler:
    yield collate_fn(dataset[index])

从可迭代风格的数据集加载大致相当于:

for data in iter(dataset):
    yield collate_fn(data)

Working with collate_fn

例如,如果每个数据样本由一个3通道图像和一个整数类标签组成,也就是说,数据集的每个元素返回一个元组(image, class_index),默认collate_fn将这样的元组列表整理成一个批处理图像张量和一个批处理类标签张量的元组。特别是,默认的collate_fn具有以下属性:

5、Single- and Multi-process Data Loading

默认情况下,DataLoader使用单进程数据加载。

torch.utils.data.get_worker_info()返回工作进程中的各种有用信息(包括工作进程id,数据集副本,初始种子等),并在主进程中返回None。用户可以在数据集代码和/或worker_init_fn中使用这个函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这在对数据集进行分片时特别有用。

6、内存固定

CLASStorch.utils.data.DataLoader

CLASStorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')

数据加载器组合了数据集和采样器,并提供了给定数据集的可迭代对象。
支持DataLoader地图样式和可迭代样式数据集,具有单进程或多进程加载、自定义加载顺序以及可选的自动批处理(整理)和内存固定。

Parameters
dataset (Dataset) – dataset from which to load the data.

batch_size (int, optional) – how many samples per batch to load (default: 1).

shuffle (bool, optional)set to True to have the data reshuffled at every epoch (default: False).

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.

batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). **Used when using batched loading from a map-style dataset**.

pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

drop_last (bool, optional)set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

timeout (numeric, optional)if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)

generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)

prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).

persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)

pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

CLASS torch.utils.data.Dataset(*args, **kwds)

CLASS torch.utils.data.IterableDataset(*args, **kwds)

可迭代的数据集。

所有表示数据样本可迭代的数据集都应该对其进行子类化。当数据来自流时,这种形式的数据集特别有用。

所有子类都应该覆盖__iter__(),这将返回此数据集中样本的迭代器。

CLASS torch.utils.data.TensorDataset(*tensors)

数据集包装张量。
每个样本将通过沿第一维度索引张量来检索。

Parameters

*tensors (Tensor) – tensors that have the same size of the first dimension.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值