pytorch加载数据

pytorch用于加载数据集的模块主要是torch.utils.data(https://pytorch.org/docs/stable/data.html)。本文详细介绍了如何在自己的项目中(针对CV)使用torch.utils.data

1 综述

1.1 pytorch常规训练过程

我们一般使用一个for循环(或多层的)来训练神经网络,每一次迭代,加载一个batch的数据,神经网络前向反向传播各一次并更新一次参数。
而这个过程中加载一个batch的数据这一步需要使用一个torch.utils.data.DataLoader对象,并且DataLoader是一个基于某个dataset的iterable,这个iterable每次从dataset中基于某种采样原则取出一个batch的数据。

1.2 torch.utils.data.DataLoader

这里先对torch.utils.data.DataLoader(下面简写为DataLoader)做一个简单讲解。
语法:

class torch.utils.data.DataLoader(
		dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 
		num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, 
		timeout=0, worker_init_fn=None, multiprocessing_context=None
		)

用处:
DataLoader将给定的sampler应用于给定的dataset,并在给定的dataset上提供一个iterable。给定的dataset可以是 map-style datasets 或 iterable-syle datasets;可以采用单进程或多进程加载数据。
参数:

  • dataset (Dataset):加载数据的数据集;
  • batch_size (int, optional):每个batch加载多少个样本,默认为1;
  • shuffle (bool, optional):如果为True,每个epoch都打乱一次数据顺序,默认为False
  • sampler (Sampler, optional):定义从dataset中抽样的策略,如果指定了,那必须有shuffle=False
  • batch_sampler (Sampler, optional):和sampler参数类似,只不过sampler一次返回一个样本的index,而batch_sampler一次返回一个batch的index,和batch_sizeshufflesamplerdrop_last都有关联;
  • num_workers (int, optional):采用多少个子进程来加载数据,num_workers=0表示在主进程中加载数据;
  • collate_fn (callable, optional):将一个batch的样本合并到一起形成一个Tensor,应用于 map-style datasets;
  • pin_memory (bool, optional):If True, the data loader will copy Tensors into CUDA pinned memory before returning them(暂时不理解);
  • drop_last (bool, optional):如果为True,那当dataset_size不能被batch_size整除时,丢弃最后一个不完整的batch的数据,如果为False,不丢弃,但最后一个batch的数据将会少一些,默认为False
  • timeout (numeric, optional):if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0);
  • worker_init_fn (callable, optional):If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

2 Dataset Types

DataLoader构造函数最重要的参数是dataset,它指定了加载数据的Dataset对象。Pytorch提供了两种不同的datasets:

  • map-style datasets
  • iterable-style datasets

2.1 Map-style datasets

  • map-style dataset:torch.utils.data.Dataset(下面简写为Dataset)的子类。实现了__getitem__()方法和__len__()方法,相当于构造了一个可以从 indices/keys 索引到 data samples 的 map。
  • 当使用dataset[idx]时会自动调用__getitem__()方法读出dataset中第idx个sample;当使用len(dataset)时会自动调用__len__()方法返回dataset中sample的数量。
  • 在pytorch中从抽象的基类torch.utils.data.Dataset定义自己的 map-style dataset。
    • 所有Dataset的子类都必须重写__getitem__()方法从而支持给定一个 index/key 返回一个 data sample;
    • 子类可以选择性地重写__len__()方法,重写了之后可以支持返回 dataset 的size。

示例:

  1. 2019年单帧图像去雨工作 PReNet 加载数据的方法(源码是https://github.com/csdwren/PReNet中的 DerainDataset.py )是先将所有图像放到 h5文件 中,然后构造了Dataset的子类并实现了__len__()方法和__getitem__()方法,而__getitem__()方法需要的 index/key 可以从 h5文件对象 的属性keys()获得。
  2. 更general的做法是将图片和标签的路径信息放到一个 txt文件中:
    1. 制作存储了图片和标签的路径信息的 txt文件;
    2. 将这些信息转化为list,此list的每个元素对应一个sample;
    3. 通过__getitem__()方法读取图片和标签并返回。

2.2 Iterable-style datasets

  • iterable-style dataset:torch.utils.data.IterableDataset(下面简写为IterableDataset)的子类。实现了__iter__()方法的torch.utils.data.IterableDataset的子类实例,是一个基于某个数据库的可迭代对象。
  • 适用于随机读取代价大或者不可能随机读取样本的情况,以及batch_size取决于所获得的数据的情况。
  • 在pytorch中从抽象的基类torch.utils.data.IterableDataset定义自己的 iterable-style dataset。在调用iter(dataset)(形成迭代器)时返回从数据库或远程服务器等读取的数据流。
  • IterableDataset的子类实例被用于DataLoader,这个dataset中的所有sample将在DataLoader迭代器产生出来,那么当DataLoadernum_workers > 0,即多进程加载数据时,每个子进程都会有一个dataset的copy,就有可能返回重复的sample,这种情况应该被避免。

由于目前还没有遇到过这种dataset,等遇到了再来详细补充细节。

3 Data Loading Order and Sampler

由于iterable-style datasets加载数据的顺序在定义可迭代对象时就已经确定了,所以下面只针对map-style dataset进行讲解。

  • 对于map-style datasets,使用torch.utils.data.Sampler(下文简称Sampler)的子类实例来指定加载数据样本时的indices/keys序列排序,从而控制加载数据的顺序,Sampler的子类实例表示的是在indices/keys序列上的可迭代对象,即所有Sampler的子类必须重写__iter__()方法,选择性重写__len__()方法。
  • DataLoader将会根据shuffle参数自动构造一个sequential sampler (shuffle=False)或shuffled sampler (shuffle=True),这里的sampler是Sampler的子类实例。
  • sampler=False时,可以给DataLoadersampler参数赋值自定义的Sampler子类实例,从而自定义加载数据的额顺序。
  • sampler=False时,可以给DataLoaderbatch_sampler参数赋值自定义的、一次返回一个batch的indices/keys的sampler子类实例。

3.1 torch.utils.data.Sampler

语法:

class torch.utils.data.Sampler(data_source)

功能:
所有sampler的基类。所有子类都必须重写__iter__()方法,提供一种在dataset elements的indices/keys上的迭代方法。选择性重写__len__()方法,返回这个迭代器的长度。
参数:

  • data_source (Dataset):dataset to sample from.

源码:
源码比较简单且帮助理解,这里放一下:

class Sampler(object):
	def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

3.2 torch.utils.data.SequentialSampler

语法:

class torch.utils.data.SequentialSampler(data_source)

功能:
顺序采样。
源码:

class SequentialSampler(Sampler):
	def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

3.3 torch.utils.data.RandomSampler

语法:

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)

功能:
随机采样。当DataLoadershuffle=True时,每个epoch都默认调用这个sampler打乱采样顺序。

3.4 其他Sampler子类

其他的Sampler子类不再一一说明,要使用的话参考:https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler,这里只放语法:

class torch.utils.data.SubsetRandomSampler(indices)
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)

4 Loading Batched and Non-Batched Data

DataLoader支持通过batch_sizedrop_lastbatch_sampler参数自动将单个提取的数据整理为一个batch。

4.1 Automatic batching(default)

  • 默认情况下,DataLoader会自动将获取的一个batch的单独的数据整理到一起形成一个Tensor,默认情况batch_size是第一个维度。
  • DataLoader的参数batch_size不为None时,DataLoader自动产生被整理为batch的数据而不是单个的samples。
  • batch_sizedrop_last参数共同指定data loader如何获得每个batch的数据的keys。
  • 对应map-style datasets,我们也可以选择性的指定batch_sampler参数来一次产生一个batch的随机keys。
    更多内容以后补充。

4.2 Disable automatic batching

5 Single- and Multi-process Data Loading

5.1 Single-process data loading (default)

5.2 Multi-process data loading

6 Memory Pinning

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值