pytorch用于加载数据集的模块主要是torch.utils.data
(https://pytorch.org/docs/stable/data.html)。本文详细介绍了如何在自己的项目中(针对CV)使用torch.utils.data
。
1 综述
1.1 pytorch常规训练过程
我们一般使用一个for
循环(或多层的)来训练神经网络,每一次迭代,加载一个batch的数据,神经网络前向反向传播各一次并更新一次参数。
而这个过程中加载一个batch的数据这一步需要使用一个torch.utils.data.DataLoader
对象,并且DataLoader
是一个基于某个dataset的iterable
,这个iterable
每次从dataset中基于某种采样原则取出一个batch的数据。
1.2 torch.utils.data.DataLoader
这里先对torch.utils.data.DataLoader
(下面简写为DataLoader
)做一个简单讲解。
语法:
class torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=None, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, multiprocessing_context=None
)
用处:
DataLoader
将给定的sampler
应用于给定的dataset
,并在给定的dataset
上提供一个iterable
。给定的dataset
可以是 map-style datasets 或 iterable-syle datasets;可以采用单进程或多进程加载数据。
参数:
- dataset (
Dataset
):加载数据的数据集; - batch_size (
int
, optional):每个batch加载多少个样本,默认为1; - shuffle (
bool
, optional):如果为True
,每个epoch都打乱一次数据顺序,默认为False
; - sampler (
Sampler
, optional):定义从dataset中抽样的策略,如果指定了,那必须有shuffle=False
; - batch_sampler (
Sampler
, optional):和sampler
参数类似,只不过sampler
一次返回一个样本的index,而batch_sampler
一次返回一个batch的index,和batch_size
、shuffle
、sampler
、drop_last
都有关联; - num_workers (
int
, optional):采用多少个子进程来加载数据,num_workers=0
表示在主进程中加载数据; - collate_fn (callable, optional):将一个batch的样本合并到一起形成一个Tensor,应用于 map-style datasets;
- pin_memory (
bool
, optional):If True, the data loader will copy Tensors into CUDA pinned memory before returning them(暂时不理解); - drop_last (
bool
, optional):如果为True
,那当dataset_size不能被batch_size整除时,丢弃最后一个不完整的batch的数据,如果为False
,不丢弃,但最后一个batch的数据将会少一些,默认为False
; - timeout (numeric, optional):if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0);
- worker_init_fn (callable, optional):If not
None
, this will be called on each worker subprocess with the worker id (an int in[0, num_workers - 1]
) as input, after seeding and before data loading. (default:None
)
2 Dataset Types
DataLoader
构造函数最重要的参数是dataset
,它指定了加载数据的Dataset
对象。Pytorch提供了两种不同的datasets:
- map-style datasets
- iterable-style datasets
2.1 Map-style datasets
- map-style dataset:
torch.utils.data.Dataset
(下面简写为Dataset
)的子类。实现了__getitem__()
方法和__len__()
方法,相当于构造了一个可以从 indices/keys 索引到 data samples 的 map。 - 当使用
dataset[idx]
时会自动调用__getitem__()
方法读出dataset中第idx
个sample;当使用len(dataset)
时会自动调用__len__()
方法返回dataset中sample的数量。 - 在pytorch中从抽象的基类
torch.utils.data.Dataset
定义自己的 map-style dataset。- 所有
Dataset
的子类都必须重写__getitem__()
方法从而支持给定一个 index/key 返回一个 data sample; - 子类可以选择性地重写
__len__()
方法,重写了之后可以支持返回 dataset 的size。
- 所有
示例:
- 2019年单帧图像去雨工作 PReNet 加载数据的方法(源码是https://github.com/csdwren/PReNet中的 DerainDataset.py )是先将所有图像放到 h5文件 中,然后构造了
Dataset
的子类并实现了__len__()
方法和__getitem__()
方法,而__getitem__()
方法需要的 index/key 可以从 h5文件对象 的属性keys()
获得。 - 更general的做法是将图片和标签的路径信息放到一个 txt文件中:
- 制作存储了图片和标签的路径信息的 txt文件;
- 将这些信息转化为
list
,此list
的每个元素对应一个sample; - 通过
__getitem__()
方法读取图片和标签并返回。
2.2 Iterable-style datasets
- iterable-style dataset:
torch.utils.data.IterableDataset
(下面简写为IterableDataset
)的子类。实现了__iter__()
方法的torch.utils.data.IterableDataset
的子类实例,是一个基于某个数据库的可迭代对象。 - 适用于随机读取代价大或者不可能随机读取样本的情况,以及batch_size取决于所获得的数据的情况。
- 在pytorch中从抽象的基类
torch.utils.data.IterableDataset
定义自己的 iterable-style dataset。在调用iter(dataset)
(形成迭代器)时返回从数据库或远程服务器等读取的数据流。 - 当
IterableDataset
的子类实例被用于DataLoader
,这个dataset中的所有sample将在DataLoader
迭代器产生出来,那么当DataLoader
的num_workers > 0
,即多进程加载数据时,每个子进程都会有一个dataset的copy,就有可能返回重复的sample,这种情况应该被避免。
由于目前还没有遇到过这种dataset,等遇到了再来详细补充细节。
3 Data Loading Order and Sampler
由于iterable-style datasets加载数据的顺序在定义可迭代对象时就已经确定了,所以下面只针对map-style dataset进行讲解。
- 对于map-style datasets,使用
torch.utils.data.Sampler
(下文简称Sampler
)的子类实例来指定加载数据样本时的indices/keys序列排序,从而控制加载数据的顺序,Sampler
的子类实例表示的是在indices/keys序列上的可迭代对象,即所有Sampler
的子类必须重写__iter__()
方法,选择性重写__len__()
方法。 DataLoader
将会根据shuffle
参数自动构造一个sequential sampler (shuffle=False
)或shuffled sampler (shuffle=True
),这里的sampler是Sampler
的子类实例。- 在
sampler=False
时,可以给DataLoader
的sampler
参数赋值自定义的Sampler
子类实例,从而自定义加载数据的额顺序。 - 在
sampler=False
时,可以给DataLoader
的batch_sampler
参数赋值自定义的、一次返回一个batch的indices/keys的sampler
子类实例。
3.1 torch.utils.data.Sampler
语法:
class torch.utils.data.Sampler(data_source)
功能:
所有sampler的基类。所有子类都必须重写__iter__()
方法,提供一种在dataset elements的indices/keys上的迭代方法。选择性重写__len__()
方法,返回这个迭代器的长度。
参数:
- data_source (
Dataset
):dataset to sample from.
源码:
源码比较简单且帮助理解,这里放一下:
class Sampler(object):
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
3.2 torch.utils.data.SequentialSampler
语法:
class torch.utils.data.SequentialSampler(data_source)
功能:
顺序采样。
源码:
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
3.3 torch.utils.data.RandomSampler
语法:
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
功能:
随机采样。当DataLoader
的shuffle=True
时,每个epoch都默认调用这个sampler打乱采样顺序。
3.4 其他Sampler
子类
其他的Sampler
子类不再一一说明,要使用的话参考:https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler,这里只放语法:
class torch.utils.data.SubsetRandomSampler(indices)
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)
4 Loading Batched and Non-Batched Data
DataLoader
支持通过batch_size
、drop_last
和batch_sampler
参数自动将单个提取的数据整理为一个batch。
4.1 Automatic batching(default)
- 默认情况下,
DataLoader
会自动将获取的一个batch的单独的数据整理到一起形成一个Tensor,默认情况batch_size
是第一个维度。 - 当
DataLoader
的参数batch_size
不为None
时,DataLoader
自动产生被整理为batch的数据而不是单个的samples。 batch_size
和drop_last
参数共同指定data loader如何获得每个batch的数据的keys。- 对应map-style datasets,我们也可以选择性的指定
batch_sampler
参数来一次产生一个batch的随机keys。
更多内容以后补充。