输入流水线
pytorch
的输入流水线的操作顺序是这样的:
- 创建一个 Dataset 对象
- 创建一个 DataLoader 对象
- 不停的循环这个 DataLoader 对象
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for data in dataloader:
....
在之前文章也提到过,如果现有的 Dataset
不能够满足需求,我们也可以自定义 Dataset
,通过继承 torch.utils.data.Dataset
。在继承的时候,需要 override
三个方法。
__init__
: 用来初始化数据集__getitem__
__len__
从本文中,您可以看到 __getitem__
和 __len__
在 DataLoader
中是如何被使用的。
DataLoader
从DataLoader
看起,下面是源码。为了方便起见,采用在源码中添加注释的形式进行解读。
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if batch_sampler is None:
if sampler is None:
if shuffle:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一个 长度为 len(dataset) 的 序列索引(随机的)。
sampler = RandomSampler(dataset)
else:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一个 长度为 len(dataset) 的 序列索引(顺序的)。
sampler = SequentialSampler(dataset)
# Sampler 是个迭代器,一次之只返回一个 索引
# BatchSampler 也是个迭代器,但是一次返回 batch_size 个 索引
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
# 以下两个代码是等价的
for data in dataloader:
...
# 等价与
iterr = iter(dataloader)
while True:
try:
next(iterr)
except StopIteration:
break
在 DataLoader
中,iter(dataloader)
返回的是一个 DataLoaderIter
对象, 这个才是我们一直 next
的 对象。
下面会先介绍一下 几个 Sampler
, 然后介绍 核心部分 DataLoaderIter
。
RandomSampler, SequentialSampler, BatchSampler
首先,是