pytorch.data模块文档翻译


Pytorch数据导入功能中核心模块是 torch.utils.data.DataLoader 类。 其构建了一个递推读取数据的函数,具有如下功能:

上述特性可以通过设置dataloader的输入参数实现

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

dataset types

  • map-style datasets:
  • iterable-style datasets

pytorch提供两种类型的数据集接口,一种是映射形式的,另外一种是递推的。

map-style datasets

映射形式数据集实现__getitem__以及__len__函数,实现了从下标键值到数据样本之间的映射。

例如,在图像领域,它可能是这样一个数据集读取模式dataset[idx]即读取了第idx-个图像对象,以及其对应的标签。

iterable-style datasets

递推数据集是IterableDataset 类的实例,主要实现了__iter__()函数,可以递推式读取数据。这种形式的数据集在数据成流式到达时适用。

例如一个这样的数据集可以调用iter(dataset)来从数据集或者远程服务器,甚至实时生成的数据中返回一个数据流。

数据载入次序与Sampler

对于iterable-style datasets ,数据载入的次序完全由用户控制。

余下本节主要讲述map-style datasets的机制。torch.utils.data.Sampler 类用于指定以什么样的次序载入数据。例如在使用SGD训练网络时,Sampler 可以每次随机(也可以不随机,顺序读取)的从数据集读取一个样本,或者在mini-batch SGD训练策略中读取一小批数据。

dataloader中的shuffle参数可以指定是采用顺序的还是随机的方式读取数据。用户也可以通过指定sampler参数值来控制采样方式。

一个定制的Sampler每次生成一批索引的列表,可以作为batch_sampler参数。还可以通过batch_size和drop_last参数启用自动批处理。有关这方面的更多细节,请参见下一节。

NOTE: iterable-style datasets中没有sampler,batch_sampler函数,因为这种数据集中没有键值或者索引

载入批或者非批数据

DataLoader支持通过参数batch_size、drop_last和batch_sampler将单个获取的数据样本自动整理成批。

自动批处理

这是最常见的情况,对应于获取一个小批数据并将其整理成成批样本,即,包含一个维度为批处理维度(通常是第一个维度)的张量。

当batch_size(默认为1)非none, dataloader将生成一批,而非单个样本。batch_size跟drop_last变量用于控制每一批数据大小以及是否丢弃最后一批大小小于batch_size的数据样本。

note: batch_size和drop_last参数本质上用于从sampler构造batch_sampler。对于map样式的数据集,采样器要么由用户提供,要么基于shuffle参数构造。对于迭代式数据集,采样器是一个虚拟的无穷大数据集。有关采样器的更多细节,请参阅本节。

note: 当从多进程递推数据集中采样数据时,drop_last参数将会把每一个进程中最后一个非完整(未达到batch_size大小)的数据批次丢弃。也就是说对每个进程批采样过程是独立计算的。

这些使用sampler采样器获得的数据样本索引之后将作为参数传递到collate_fn函数中,没错,collate_fn是一个函数句柄,它的功用是将数据样本排序成批。

在map风格数据集中,这个过程大致如下:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

在iterable风格数据集中,该过程大致如下:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

可以通过修改collate_fn实现诸如在一个非完整batch中使用padding操作填充数据。

关闭自动批处理

在某些情况下,用户可能希望在dataset代码中手动处理批处理,或者只加载单个样本。例如,直接加载成批数据(例如,从数据库批量读取或读取连续的内存块)可能更便宜,或者批大小依赖于数据,或者程序设计用于处理单个样本。在这些场景下,最好不要使用自动批处理(其中collate_fn用于对样本进行排序),而是让数据加载器直接返回dataset对象的每个成员。

当batch_size和batch_sampler都为None (batch_sampler的默认值已经为None)时,将禁用自动批处理。使用作为collate_fn参数传递的函数处理从数据集获得的每个样本。

当自动批处理被禁用后,默认的collate_fn只是简单的将原始numpy array数据转换为Tensors格式,未作其他改动。

在这种情况下,从map样式的数据集加载大致相当于:

for index in sampler:
    yield collate_fn(dataset[index])

从iterable样式数据集中加载大致相当于:

for data in iter(dataset):
    yield collate_fn(data)
collate_fn

启用或禁用自动批处理时,collate_fn的使用略有不同。

当禁用自动批处理时,将对每个单独的数据样本调用collate_fn,并从数据加载器dataloader迭代器生成输出。在本例中,默认的collate_fn只只是将NumPy数组转换为PyTorch张量。

当启用自动批处理时,每次调用collate_fn时都带有一个数据样本列表。希望它将输入样本整理成一批,以便从数据加载器迭代器生成。本节的其余部分将讲述这种默认collate_fn的行为。

例如,如果每个数据样本由一个3通道图像和一个整数类标签组成,即,数据集的每个元素返回一个元组(image, class_index),默认的collate_fn将这样的元组列表整理成一个由成批图像张量和成批类标签张量组成的单元组。特别是,默认的collate_fn具有以下属性:

  • 它总是预先添加一个新的维度作为批处理维度。
  • 它自动将NumPy数组和Python数值转换为PyTorch张量。
  • 它保留了数据结构,例如,如果每个样本都是一个字典dict,那么它输出的字典具有相同的键值集,但是将张量tensor作为值(如果不能将值转换为张量,则输出列表)。列表list、元组tuple、命名元组namedtuple也是如此。

用户可以使用定制的collate_fn来实现指定的批处理,例如,沿着第一个维度以外的维度进行排序,填充padding不同长度的序列,或者添加对定制数据类型的支持。

Single- and Multi-process Data Loading

dataloader默认使用单进程数据载入模式。

在Python进程中,全局解释器锁(GIL)限制了真正的跨线程完全并行化Python代码。为了避免使用数据加载阻塞计算代码,PyTorch提供了一个简单的开关来执行多进程数据加载,只需将参数num_workers设置为正整数。

Single-process data loading (default)

在此模式下,数据获取与初始化DataLoader的过程相同。因此,数据加载可能会阻塞计算。但是,当用于在进程之间共享数据的资源(例如共享内存、文件描述符)有限时,或者当整个数据集很小并且可以完全装入内存时,这种模式可能是首选的。此外,单进程加载通常显示更多可读的错误跟踪,因此对调试非常有用。

Multi-process data loading

将参数num_workers设置为正整数将使用num_workers个多进程进行数据加载。

在这种模式下,每次创建DataLoader的迭代器(例如,调用enumerate(DataLoader))时,都会创建num_workers个进程。此时,dataset、collate_fn和worker_init_fn被传递给每个worker,用于初始化和获取数据。这意味着数据集访问及其内部IO、转换(包括collate_fn)都在这个worker进程中运行。

torch.utils.data.get_worker_info()在工作进程中返回各种有用的信息(包括工作id、数据集副本、初始种子等等),在主进程中不返回任何信息。用户可以在dataset代码和/或worker_init_fn中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这对于数据集分片特别有用。

对于map样式的数据集,主进程使用采样器生成索引并将其发送给worker。因此,任何随机打乱过程都是在主进程中完成的,主进程通过为load分配索引来引导各子进程数据载入。

对于iterable风格的数据集,由于每个工作进程都获得dataset对象的副本,所以简单的多进程加载常常会导致重复的数据。使用torch.utils.data.get_worker_info()和/或worker_init_fn,用户可以独立配置每个副本。(有关如何实现这一点,请参阅IterableDataset文档。)出于类似的原因,在多进程加载中,drop_last参数将会删除每个worker的iterable样式数据集副本的最后一个非完整批数据。

一旦到达迭代的末尾,或者迭代器变为垃圾回收后,worker就会被关闭。

Warning: 通常不建议在多进程加载时返回CUDA张量,因为在多进程中使用CUDA和共享CUDA张量有很多微妙之处(参见多进程中的CUDA CUDA in multiprocessing)。相反,我们建议使用自动内存固定(即,设置pin_memory=True),使数据能够快速传输到支持cuda的gpu。

Randomness in multi-process data loading

默认情况下,每个worker都将其PyTorch种子设置为base_seed + worker_id,其中base_seed是由使用其RNG的主进程生成的长进程(因此强制使用RNG状态)。但是,其他库的种子可以在初始化worker时复制(w.g.Numpy)。,导致每个worker返回相同的随机数。(参见FAQ.中的这一节 this section)。

worker_init_fn,可以使用 torch.utils.data.get_worker_info().seed or torch.initial_seed() 来查看每个worker中的随机种子集,也可以用其在载入数据之前设置其他库的种子。

Memory Pinning

当GPU副本来自固定(页面锁定)内存时,主机到GPU副本的速度要快得多。有关何时以及如何使用固定内存的详细信息,请参阅使用固定内存缓冲区 Use pinned memory buffers

对于数据加载,把pin_memory=True传递给DataLoader,自动将获取的数据张量放入固定内存中,从而能够更快地将数据传输到支持cuda的gpu。

默认的内存固定逻辑只识别张量Tensor、映射map和包含张量的迭代器。默认情况下,如果pinning logic看到了一个batch数据类型是自定义的(如果collate_fn返回来一个自定义batch类型时会发生上述情况),或者如果batch的每个元素是一个自定义类型,pinning logic将无法识别他们,它会不做pinning memory处理的返回这批(或这些元素)。要为自定义batch或数据类型启用内存固定,请在自定义类型上定义pin_memory()方法。

例子如下:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

class

CLASS

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

数据载入器。由数据集与采样器组合而成,对指定数据集生成一个迭代采样器。

  • 参数:
    • sampler:定义从数据集中采样的方式,如果启用该布尔值则shuffle应置为False
  • note: len(dataloader)取决于所使用的采样器长度。由于IterableDataset使用的是无穷长的采样器,其实际长度取决于迭代过程及多进程,因此__len__()函数没有实现,用这种数据集类型时就不要调用len()函数了,肯定会报错的

dataloader 是一个可迭代对象,使用iter()访问,不能使用next()

CLASS

torch.utils.data.Dataset

描述数据集的抽象类。

所有表示从键到数据样本的映射的数据集都应该子类化该类。所有子类都应该重载__getitem__(),从而为给定键值获取数据样本。子类还可以选择性地重载__len__(),通过采样器实现和DataLoader的默认选项,返回数据集的大小。

CLASS

torch.utils.data.IterableDataset

iterable风格数据集

该数据集类适合处理那些数据不断流入的数据集

其子类需要重载__iter__(),返回一个数据集样本的迭代器。

当其子类与DataLoader一起使用时,数据集中的每项都将从DataLoader迭代器中生成。当num_workers > 0时,每个worker进程将具有不同的dataset对象副本,因此通常希望独立配置每个副本,以避免从workers返回重复的数据。在工作进程中调用get_worker_info()时,返回有关worker进程的信息。它可以在dataset的__iter__()方法中使用,也可以在DataLoader的worker_init_fn选项中使用,以修改每个副本的行为。

CLASS

torch.utils.data.TensorDataset(*tensors)

数据集包装张量。

每个样本都将通过索引第一个维度上的张量来检索。

CLASS

 torch.utils.data.ChainDataset(datasets)

用于链式链接多个IterableDataset 型数据集

该类用于组合多个dataset流。链式操作是在运行过程中实时完成的,在处理大规模数据集拼接时较为高效。

CLASS

 torch.utils.data.Subset(dataset, indices)

提取数据集中指定索引构成子集

torch.utils.data.random_split(dataset, lengths)

随机地将数据集分割为给定长度的非重叠的新数据集。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值