Pytorch DataLoader shuffle 参数源码解读

DataLoader 的使用方法:

  1. 调用 dataloader. __iter__ 获取迭代器
  2. 调用 迭代器的 __next__ 获取下一个 batch

首先 dataloader 可以设置是否 shuffle
那么只要看 shuffle 参数对这个过程有什么影响即可

class DataLoader(Generic[T_co]):
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):

注释:

shuffle (bool, optional): set to True to have the data reshuffled at every epoch (default: False).

可以看到数据会在每个 epoch 中被 reshuffle。
其实现中,直接相关的代码有:

if shuffle:
     sampler = RandomSampler(dataset, generator=generator)  # type: ignore
else:
     sampler = SequentialSampler(dataset)

# ......
self.sampler = sampler

若设置了 shuffle 为 True 则将采样实例化为随机采样。
当调用 dataloader 的 iter() :

def __iter__(self) -> '_BaseDataLoaderIter':
    if self.persistent_workers and self.num_workers > 0:
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()

会调用 self._get_iterator() 方法,并返回生成的 iterator,用户通过调用该 iterator 的 next() 方法获取每个 batch。

def _get_iterator(self) -> '_BaseDataLoaderIter':
   if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
   else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)

该方法先判断是否多线程,然后选择生成对应的 DataLoaderIter,以单线程为例:

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

没有看到该类中的 __next__实现,看其父类 _BaseDataLoaderIter ,无关代码全删了。

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._index_sampler = loader._index_sampler
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        # 以及其他属性

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()
            self._num_yielded += 1
            return data

_BaseDataLoaderIter 调用了 dataloader 的 _index_sampler 方法,获取了一个 sampler,并调用其 __iter__得到一个迭代器,以迭代获取 batch 中元素的下标。

当用户调用 _BaseDataLoaderIter 其子类的 __next__ 时,调用由子类实现的 _next_data() 方法获取 batch,并然给当前的计数 +1。

当该迭代器用完的时候就会调用 self._reset() 重新生成一个新的迭代器,并将计数归零。

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called
_next_data() 在子类 _SingleProcessDataLoaderIter 中重写 :
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

逻辑为调用 self._next_index() 方法获取当前要返回的 batch 中元素,在 dataset 中的下标 index,然后使用 self._dataset_fetcher.fetch(index) 获取对应元素后返回即可。
_next_index 在父类中实现:

def _next_index(self):
    return next(self._sampler_iter) 

其实就是调用之前得到的迭代器的 __next__ 以获取下一组 index。

那么 sampler 是怎样获取 index 的呢?以下是之前出现过的 RandomSampler 的代码:
class RandomSampler(Sampler[int]):

def __init__(self, data_source: Sized, replacement: bool = False,
             num_samples: Optional[int] = None, generator=None) -> None:
    self.data_source = data_source
    self.replacement = replacement
    self._num_samples = num_samples
    self.generator = generator

def __iter__(self):
    n = len(self.data_source)
    if self.generator is None:
        generator = torch.Generator()
        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
    else:
        generator = self.generator
    if self.replacement:
        for _ in range(self.num_samples // 32):
            yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
        yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
    else:
        yield from torch.randperm(n, generator=self.generator).tolist()

其实只要看一行就够了

yield from torch.randperm(n, generator=self.generator).tolist()

torch.randperm 就是返回元素为从 0 到 n-1 的张量,将其做 tolist 之后再使用 yield from 将其内的元素一个个 yield 出来,所以返回的就是 0 到 n-1的随机数字,且不重复。

但是!DataLoaderIter 里头用的并不单单是这个 sampler,这是由于若使用该 sampler, 一次只能获取一条数据,但是咱们是一个 batch, 一个 batch 有多条数据。
回看_BaseDataLoaderIter 中, loader._index_sampler 调用的是以下方法

@property
def _index_sampler(self):
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

self._auto_collation = True 时 就会选择 self.batch_sampler

@property
def _auto_collation(self):
    return self.batch_sampler is not None

self._auto_collation 什么时候等于 True 呢? 当 self.batch_sampler is not None。那 self.batch_sampler 啥时候会为 None 呢?

if batch_size is not None and batch_sampler is None:
     batch_sampler = BatchSampler(sampler, batch_size, drop_last)

当 batch_size is None 的时候,好吧跟我们现在的情况没什么关系,_index_sampler 返回的就是 self.batch_sampler。

batch_sampler 在实例化时需要传入 sampler, batch_size, drop_last, 此处的 sampler 就是之前的 RandomSampler。batch_sampler 的 __iter__ 实现为:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

首先从获取了 sampler 的迭代器,并且逐个获取 sampler 中的内容也就是数据的下标,当获取的数量与 batch size 相同时将会 yield 出去。

好了那么就非常清楚了:

  1. 调用 dataLoader.__iter__ 得到生成器 dataLoaderIter
  2. 用户调用 dataLoaderIter 的 __next__ 想要获取下一个 batch
  3. dataLoaderIter 调用 dataLoader 的 sampler.__next__ 获取一组下标用于从 dataset 获取 data
  4. dataLoader 的 sampler 是一个 batchSampler,它从 randomSampler 中获取所有打乱的 index
  5. batchSampler 根据 batch size 将 index 一组一组地交给 dataLoaderIter
  6. dataLoaderIter 调用 _dataset_fetcher.fetch(index) 获取该 batch 的 data 返回给用户

以上就是当 shuffle = True 时的获取逻辑了。当然,当 shuffle = False 时,使用的就不是 RandomSampler 而是 SequentialSampler 了,那么 index 就是顺序不变的,batch 内的元素也就不变了

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: PyTorch DataLoader是一个用于批量加载数据的工具,它可以帮助用户在训练模型时高效地加载和处理大规模数据集。DataLoader可以根据用户定义的批量大小、采样方法、并行加载等参数来自动将数据集分成小批量,并且可以在GPU上并行加载数据以提高训练效率。 使用DataLoader需要先定义一个数据集对象,然后将其传递给DataLoader。常用的数据集对象包括PyTorch自带的Dataset类和用户自定义的数据集类。在DataLoader中可以指定批量大小、是否打乱数据、并行加载等参数。 下面是一个示例代码: ```python import torch from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self): self.data = torch.randn(100, 10) self.label = torch.randint(0, 2, size=(100,)) def __getitem__(self, index): return self.data[index], self.label[index] def __len__(self): return len(self.data) dataset = MyDataset() dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2) for data, label in dataloader: print(data.shape, label.shape) ``` 在上面的示例中,我们定义了一个自己的数据集类MyDataset,并将其传递给DataLoader。然后指定了批量大小为10,打乱数据,使用2个进程来并行加载数据。在循环中,每次从DataLoader中取出一个批量的数据和标签,并输出它们的形状。 ### 回答2: PyTorchDataLoader是一个用于加载数据的实用工具。它可以帮助我们高效地加载和预处理数据,以供深度学习模型使用。 DataLoader有几个重要参数。首先是dataset,它定义了我们要加载的原始数据集。PyTorch提供了几种内置的数据集类型,也可以自定义数据集。数据集可以是图片、文本、音频等。 另一个重要参数是batch_size,它定义了每个批次中加载的数据样本数量。这是非常重要的,因为深度学习模型通常需要在一个批次上进行并行计算。较大的批次可以提高模型的训练速度,但可能需要更多的内存。 DataLoader还支持多线程数据加载。我们可以使用num_workers参数来指定并行加载数据的线程数。这可以加快数据加载的速度,特别是当数据集很大时。 此外,DataLoader还支持数据的随机打乱。我们可以将shuffle参数设置为True,在每个轮次开始时随机重新排序数据。这对于训练深度学习模型非常重要,因为通过在不同轮次中提供不同样本的顺序,可以增加模型的泛化能力。 在使用DataLoader加载数据后,我们可以通过迭代器的方式逐批次地获取数据样本。每个样本都是一个数据批次,包含了输入数据和对应的标签。 总的来说,PyTorchDataLoader提供了一个简单而强大的工具,用于加载和预处理数据以供深度学习模型使用。它的灵活性和可定制性使得我们可以根据实际需求对数据进行处理,并且能够高效地并行加载数据,提高了训练的速度。 ### 回答3: PyTorchDataLoader是一个用于数据加载和预处理的实用程序类。它可以帮助我们更有效地加载和处理数据集,并将其用于训练和评估深度学习模型。 DataLoader的主要功能包括以下几个方面: 1. 数据加载:DataLoader可以从不同的数据中加载数据,例如文件系统、内存、数据库等。它接受一个数据集对象作为输入,该数据集对象包含实际的数据和对应的标签。DataLoader可以根据需要将数据集分成小批量加载到内存中,以减少内存占用和加速训练过程。 2. 数据预处理:DataLoader可以在加载数据之前对数据进行各种预处理操作,包括数据增强、标准化、裁剪和缩放等。这些预处理操作可以提高模型的泛化能力和训练效果。 3. 数据迭代:DataLoader将数据集划分为若干个小批量,并提供一个可迭代的对象,使得我们可以使用for循环逐个访问这些小批量。这种迭代方式使得我们能够更方便地按批次处理数据,而无需手动编写批处理循环。 4. 数据并行加载:DataLoader支持在多个CPU核心上并行加载数据,以提高数据加载的效率。它使用多线程和预读取的机制,在一个线程中预先加载数据,而另一个线程处理模型的训练或推理过程。 总之,PyTorchDataLoader是一个方便且高效的工具,帮助我们更好地管理和处理数据集。它可以加速深度学习模型的训练过程,并提供了一种简单而灵活的数据加载和迭代方式。使用DataLoader可以让我们更专注于模型的设计和调优,而无需过多关注数据的处理和加载细节。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值