【pytorch】Dataloader学习笔记

Outline

  • Pytorch中加载数据集的核心类为torch.utils.data.Dataloder,其作用是加载,并将torch.utils.data.Dataset中的元素转为tensor数据类型。加载Dataset中的元素、并控制Dataset元素加载次序由Sampler或者BatchSampler类控制。将Dataset中的元素转为torch.tensor类型由collator_fn可调用对象控制。

  • 参数Dataset表示待处理的源数据集。Dataloader支持两种类型的Dataset——“map-style” 与 “iterable-style”。iterable-style Dataset可以理解为一个迭代器,每次输出一个或者一组Dataset中的元素。

  • 由于iterable-style Dataset自定义程度较高,本文主要焦距于map-style Dataset。iterable-style类型的处理逻辑大致如下:

    dataset = iter(dataset)
    
    # Non-Batched mode
    for data in dataset:
    	collator_fn(data)
    
    # Batched mode
    for indices in batch_sampler:
    	collator_fn([next(dataset) for idx in indices]
    
  • map-style Dataset可以理解为“每一Dataset中的元素值都可以通过一个key获取”, 如dataset[idx]

  • 注意,同样是每次迭代中处理一条样本, batch_size = 1与 batch_size=None是不同的,前者在创建的tensor中会新建一个batch_size维度。

Dataset Type

Dataset是Dataloader实例化中最重要的参数,代表了待处理的数据集。DataLoader支持map-style与iterable-style两种类型的Dataset。

map-style

map-style dataset represents a map from key to data sample

map-style类型的数据集类需要实现__getitem__与__len__协议。这种类型的数据集通过key就可以取到(Fetch)对应的样本数据。torch.utils.data.Dataset是map-style类型的代表, 如果需要自定义map-style 数据集类,应该继承torch.utils.data.Dataset, 并重实现__getitem____len__。 Dataset的部分源码如下:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
  • torch.utils.data.Dataset 只是map-style类型的一个基础类。如果不愿自定义, 可使用TensorDataset类, 其部分源码如下:

    class TensorDataset(Dataset[Tuple[Tensor, ...]]):
        r"""Dataset wrapping tensors.
        Each sample will be retrieved by indexing tensors along the first dimension.
        Args:
            *tensors (Tensor): tensors that have the same size of the first dimension.
        """
        tensors: Tuple[Tensor, ...]
    
        def __init__(self, *tensors: Tensor) -> None:
            assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
            self.tensors = tensors
    
        def __getitem__(self, index): 
        	""" 本质是依次对每个传入的tensor,在第一个维度根据指定索引键取值,然后将所有值以元组组装起来。 """
            return tuple(tensor[index] for tensor in self.tensors)
    
    
  • Dataloader默认创建的sampler或者batch_sampler,只会输出整数型索引——样本在Dataset中的索引编号。如果自定义型的map-style数据集需要通过非整数型key来获取样本,需要创建自定义类型的sampler。

iterable-style

iterable-style dataset represent an iterable of data samples。

iterable-style 类型的dataset类需要实现__iter__协议。torch.utils.data.IterableDataset是该类型的典型代表。

Data Loading Order

在Dataloader中,样本数据的加载顺序主要针对map-style类型的数据集,iterable-style类型的数据集是按照自定义的顺序依次输出数据,输出次序的逻辑由IterableDataset子类通过__iter__控制。对于map-style类型数据集:

  • 通过sampler生成器控制键生成的次序,然后通过键控制数据的加载顺序。
  • sampler生成器按照每次迭代返回键的数量,可以分为每次迭代返回一个索引键、或者一组索引键,分别对应 Batched 与 Non-Batched 数据加载模式

Sampler

Sampler是一个抽象基类,子类通过自定义__iter__方法,返回Dataset元素键集合的迭代器,其源码如下所示:

class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

比较常用的sampler包括SequentialSampler、RandomSampler、BatchSampler。

SequentialSampler

从名称可以看出,SequentialSampler会生成次序的索引序列,因此数据集加载顺序总是固定的,其部分源码如下所示:

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.
    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

RandomSampler

从名称可以看出,RandomSampler生成随机索引序列,其部分源码如下:

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:  # 重复采样
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:  # 非重复采样
            yield from torch.randperm(n, generator=generator).tolist() 

BatchSampler

BatchSampler每次迭代返回一组索引编号。

class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

Loading Batched or Non-Batched

如果sampler单次返回一个索引值,则Dataloader每次迭代处理一条样本。如果sampler每次迭代返回一组索引值,则Dataloader每次迭代处理一个batch的样本。默认会以Batched模式进行样本加载与处理,如果需要以单条样本进行处理,需要设置batch_size=None并且batch_sampler=None

collator_fn

collator_fn是处理Dataset元素的最后一步,对于map-style类型的数据集,collator_fn的作用可类比:

# Non-batched 
for index in sampler:
    yield collate_fn(dataset[index])

# Batched
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

Single-Process and Multi-Process Data Loading

  • 默认采用单进程方式
  • TODO

 

参考资料

TORCH.UTILS.DATA official document

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: PyTorch DataLoader是一个用于批量加载数据的工具,它可以帮助用户在训练模型时高效地加载和处理大规模数据集。DataLoader可以根据用户定义的批量大小、采样方法、并行加载等参数来自动将数据集分成小批量,并且可以在GPU上并行加载数据以提高训练效率。 使用DataLoader需要先定义一个数据集对象,然后将其传递给DataLoader。常用的数据集对象包括PyTorch自带的Dataset类和用户自定义的数据集类。在DataLoader中可以指定批量大小、是否打乱数据、并行加载等参数。 下面是一个示例代码: ```python import torch from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self): self.data = torch.randn(100, 10) self.label = torch.randint(0, 2, size=(100,)) def __getitem__(self, index): return self.data[index], self.label[index] def __len__(self): return len(self.data) dataset = MyDataset() dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2) for data, label in dataloader: print(data.shape, label.shape) ``` 在上面的示例中,我们定义了一个自己的数据集类MyDataset,并将其传递给DataLoader。然后指定了批量大小为10,打乱数据,使用2个进程来并行加载数据。在循环中,每次从DataLoader中取出一个批量的数据和标签,并输出它们的形状。 ### 回答2: PyTorchDataLoader是一个用于加载数据的实用工具。它可以帮助我们高效地加载和预处理数据,以供深度学习模型使用。 DataLoader有几个重要参数。首先是dataset,它定义了我们要加载的原始数据集。PyTorch提供了几种内置的数据集类型,也可以自定义数据集。数据集可以是图片、文本、音频等。 另一个重要参数是batch_size,它定义了每个批次中加载的数据样本数量。这是非常重要的,因为深度学习模型通常需要在一个批次上进行并行计算。较大的批次可以提高模型的训练速度,但可能需要更多的内存。 DataLoader还支持多线程数据加载。我们可以使用num_workers参数来指定并行加载数据的线程数。这可以加快数据加载的速度,特别是当数据集很大时。 此外,DataLoader还支持数据的随机打乱。我们可以将shuffle参数设置为True,在每个轮次开始时随机重新排序数据。这对于训练深度学习模型非常重要,因为通过在不同轮次中提供不同样本的顺序,可以增加模型的泛化能力。 在使用DataLoader加载数据后,我们可以通过迭代器的方式逐批次地获取数据样本。每个样本都是一个数据批次,包含了输入数据和对应的标签。 总的来说,PyTorchDataLoader提供了一个简单而强大的工具,用于加载和预处理数据以供深度学习模型使用。它的灵活性和可定制性使得我们可以根据实际需求对数据进行处理,并且能够高效地并行加载数据,提高了训练的速度。 ### 回答3: PyTorchDataLoader是一个用于数据加载和预处理的实用程序类。它可以帮助我们更有效地加载和处理数据集,并将其用于训练和评估深度学习模型。 DataLoader的主要功能包括以下几个方面: 1. 数据加载:DataLoader可以从不同的数据源中加载数据,例如文件系统、内存、数据库等。它接受一个数据集对象作为输入,该数据集对象包含实际的数据和对应的标签。DataLoader可以根据需要将数据集分成小批量加载到内存中,以减少内存占用和加速训练过程。 2. 数据预处理:DataLoader可以在加载数据之前对数据进行各种预处理操作,包括数据增强、标准化、裁剪和缩放等。这些预处理操作可以提高模型的泛化能力和训练效果。 3. 数据迭代:DataLoader将数据集划分为若干个小批量,并提供一个可迭代的对象,使得我们可以使用for循环逐个访问这些小批量。这种迭代方式使得我们能够更方便地按批次处理数据,而无需手动编写批处理循环。 4. 数据并行加载:DataLoader支持在多个CPU核心上并行加载数据,以提高数据加载的效率。它使用多线程和预读取的机制,在一个线程中预先加载数据,而另一个线程处理模型的训练或推理过程。 总之,PyTorchDataLoader是一个方便且高效的工具,帮助我们更好地管理和处理数据集。它可以加速深度学习模型的训练过程,并提供了一种简单而灵活的数据加载和迭代方式。使用DataLoader可以让我们更专注于模型的设计和调优,而无需过多关注数据的处理和加载细节。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值