# 自上而下理解三者关系

class DataLoader(object):
    """Skeleton of PyTorch's DataLoader showing how the Sampler, the Dataset
    and ``collate_fn`` cooperate during one single-process iteration step.

    (Excerpt for illustration — the real class has many more members.)
    """
    ...

    def __next__(self):
        # Single-process path: num_workers == 0 means data is loaded inline,
        # in the caller's process, with no worker subprocesses involved.
        if self.num_workers == 0:
            # Sampler: yields one batch worth of dataset indices.
            indices = next(self.sample_iter)
            # Dataset: fetch each sample by index, then merge them into a batch.
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                # Optionally copy tensors into page-locked (pinned) host memory
                # to speed up host-to-GPU transfers.
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

# Sampler

## 参数传递

class DataLoader(object):
    """Signature excerpt of DataLoader's constructor, showing how the
    sampler-related parameters (``batch_size``, ``shuffle``, ``sampler``,
    ``batch_sampler``, ``drop_last``) are passed in.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        # Body elided — this excerpt only documents the parameter list.
        ...

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch中已经实现的Sampler有如下几种：

• SequentialSampler
• RandomSampler
• WeightedRandomSampler
• SubsetRandomSampler

• 如果你自定义了batch_sampler，那么这些参数都必须使用默认值：batch_size、shuffle、sampler、drop_last。
• 如果你自定义了sampler，那么shuffle需要设置为False
• 如果sampler和batch_sampler都为None，那么batch_sampler使用Pytorch已经实现好的BatchSampler，而sampler分两种情况：
• shuffle=True,则sampler=RandomSampler(dataset)
• shuffle=False,则sampler=SequentialSampler(dataset)

## 如何自定义Sampler和BatchSampler？

class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
        :class:`~torch.utils.data.DataLoader`, but is expected in any
        calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source):
        # The base class keeps no state; subclasses decide what to store.
        pass

    def __iter__(self):
        # Subclasses must implement this to yield dataset indices.
        raise NotImplementedError

    def __len__(self):
        # NOTE(review): relies on subclasses assigning self.data_source;
        # the base __init__ above does not set it.
        return len(self.data_source)

# Dataset

Dataset定义方式如下：

class Dataset(object):
    """Skeleton showing the map-style Dataset protocol: a Dataset must
    implement ``__getitem__`` (fetch one sample by index) and ``__len__``
    (total number of samples)."""

    def __init__(self):
        ...

    def __getitem__(self, index):
        # Return the sample at `index`; typically an (input, label) pair.
        return ...

    def __len__(self):
        # Return the total number of samples.
        return ...

class DataLoader(object):
    """Repeat of the ``__next__`` skeleton, this time to highlight the line
    where the Dataset is actually indexed (``self.dataset[i]``)."""
    ...

    def __next__(self):
        # Only the single-process (num_workers == 0) branch is shown here.
        if self.num_workers == 0:
            # One batch worth of indices produced by the sampler iterator.
            indices = next(self.sample_iter)
            # The line under discussion: each index i reads one sample via
            # Dataset.__getitem__, and collate_fn merges the samples.
            batch = self.collate_fn([self.dataset[i] for i in indices]) # this line
            if self.pin_memory:
                # Copy the batch into pinned host memory when requested.
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

• indices: 表示每一个iteration，sampler返回的indices，即一个batch size大小的索引列表
• self.dataset[i]: 前面已经介绍了，这里就是对第i个数据进行读取操作，一般来说self.dataset[i]=(img, label)

• 21
点赞
• 46
收藏
• 0
评论
07-29 1734
12-15 372
08-07 650
10-07 6799
11-23 8032
09-18 9205

• 非常没帮助
• 没帮助
• 一般
• 有帮助
• 非常有帮助