Outline
- The core class for loading data in PyTorch is torch.utils.data.DataLoader. Its job is to load the elements of a torch.utils.data.Dataset and convert them into tensors. Which elements are fetched from the Dataset, and in what order, is controlled by a Sampler or BatchSampler; converting the fetched elements into torch.Tensor objects is handled by the collate_fn callable.
- The Dataset argument is the source dataset to be processed. DataLoader supports two kinds of Dataset: map-style and iterable-style. An iterable-style Dataset can be thought of as an iterator that yields one element, or one group of elements, per step.
- Since iterable-style Datasets are largely user-defined, this article focuses mainly on map-style Datasets. The handling of the iterable-style case is roughly:

dataset = iter(dataset)

# Non-Batched mode
for data in dataset:
    collate_fn(data)

# Batched mode
for indices in batch_sampler:
    collate_fn([next(dataset) for idx in indices])
- A map-style Dataset can be understood as one where every element can be retrieved through a key, as in dataset[idx].
- Note that although both process one sample per iteration, batch_size=1 and batch_size=None are not the same: the former adds an extra batch dimension to the created tensor.
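The difference is visible directly in the shapes DataLoader produces. A minimal sketch using TensorDataset (the data here is made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Four samples, each a feature vector of length 3.
data = torch.arange(12, dtype=torch.float32).reshape(4, 3)
dataset = TensorDataset(data)

# batch_size=1: automatic batching adds a leading batch dimension.
batched = next(iter(DataLoader(dataset, batch_size=1)))
print(batched[0].shape)    # torch.Size([1, 3])

# batch_size=None: automatic batching is disabled, samples keep their shape.
unbatched = next(iter(DataLoader(dataset, batch_size=None)))
print(unbatched[0].shape)  # torch.Size([3])
```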
Dataset Type
Dataset is the most important argument when constructing a DataLoader; it represents the dataset to be processed. DataLoader supports both map-style and iterable-style Datasets.
map-style
A map-style dataset represents a map from keys to data samples.
A map-style dataset class implements the __getitem__ and __len__ protocols, so the sample for a given key can be fetched directly. torch.utils.data.Dataset is the canonical map-style class; to define a custom map-style dataset, subclass torch.utils.data.Dataset and override __getitem__ and __len__. Part of the Dataset source:
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
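As a sketch of the protocol above, here is a hypothetical subclass that serves each index together with its square (the class name and data are invented for illustration):

```python
import torch
from torch.utils.data import Dataset

# A minimal custom map-style dataset: overrides __getitem__ and __len__.
class SquaresDataset(Dataset):
    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int) -> torch.Tensor:
        # Fetch the sample associated with the integer key `index`.
        return torch.tensor([index, index ** 2])

ds = SquaresDataset(5)
print(len(ds))         # 5
print(ds[3].tolist())  # [3, 9]
```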
torch.utils.data.Dataset is only the base class for map-style datasets. If you do not want to define your own, you can use the TensorDataset class, whose (abridged) source is:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # Index every wrapped tensor along its first dimension with the given
        # key, then pack the results into a tuple.
        return tuple(tensor[index] for tensor in self.tensors)
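A short usage sketch of TensorDataset (the tensors here are made up for illustration):

```python
import torch
from torch.utils.data import TensorDataset

# Two aligned tensors: 10 feature vectors and 10 labels.
features = torch.randn(10, 4)
labels = torch.arange(10)
ds = TensorDataset(features, labels)

# Indexing fetches from every wrapped tensor along dim 0 and returns a tuple.
x, y = ds[2]
print(x.shape)   # torch.Size([4])
print(y.item())  # 2
```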
By default, the sampler or batch_sampler that DataLoader creates yields only integer indices, i.e. the positions of samples within the Dataset. If a custom map-style dataset fetches samples by non-integer keys, a custom sampler must be provided.
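One way to picture this: a hypothetical dataset keyed by strings, paired with a custom Sampler that yields those string keys (all names here are invented for illustration; batch_size=None disables automatic batching so each key is fetched individually):

```python
from torch.utils.data import DataLoader, Dataset, Sampler

# A map-style dataset whose keys are strings rather than integers.
class NamedDataset(Dataset):
    def __init__(self, table: dict):
        self.table = table

    def __len__(self) -> int:
        return len(self.table)

    def __getitem__(self, key: str):
        return self.table[key]

# A custom sampler that yields the non-integer keys.
class KeySampler(Sampler):
    def __init__(self, keys):
        self.keys = keys

    def __iter__(self):
        return iter(self.keys)

    def __len__(self) -> int:
        return len(self.keys)

ds = NamedDataset({"a": 1, "b": 2})
loader = DataLoader(ds, sampler=KeySampler(["b", "a"]), batch_size=None)
print(list(loader))  # [2, 1]
```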
iterable-style
An iterable-style dataset represents an iterable of data samples.
An iterable-style dataset class implements the __iter__ protocol; torch.utils.data.IterableDataset is the canonical example.
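A minimal sketch of an iterable-style dataset (the class name is made up):

```python
from torch.utils.data import DataLoader, IterableDataset

# Minimal iterable-style dataset: only __iter__ is implemented.
class CountStream(IterableDataset):
    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

# batch_size=None disables automatic batching; samples pass through one by one.
loader = DataLoader(CountStream(3), batch_size=None)
print(list(loader))  # [0, 1, 2]
```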
Data Loading Order
In a DataLoader, the loading order of samples is mainly a concern for map-style datasets; an iterable-style dataset emits data in whatever order its IterableDataset subclass defines in __iter__. For map-style datasets:
- A sampler generator controls the order in which keys are produced, and the keys in turn determine the order in which data is loaded.
- Depending on how many keys the sampler returns per iteration, a single index or a group of indices, loading runs in Non-Batched or Batched mode respectively.
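For instance, DataLoader's shuffle flag simply selects between two built-in samplers internally; a small check, assuming the loader's sampler attribute:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(4))

# shuffle=False (the default) uses a SequentialSampler under the hood...
seq_loader = DataLoader(ds, batch_size=2, shuffle=False)
print(type(seq_loader.sampler).__name__)   # SequentialSampler

# ...while shuffle=True swaps in a RandomSampler.
rand_loader = DataLoader(ds, batch_size=2, shuffle=True)
print(type(rand_loader.sampler).__name__)  # RandomSampler
```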
Sampler
Sampler is an abstract base class; a subclass customizes __iter__ to return an iterator over the keys of the Dataset's elements. Its source:
class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
Commonly used samplers include SequentialSampler, RandomSampler, and BatchSampler.
SequentialSampler
As the name suggests, SequentialSampler generates indices in sequential order, so the dataset is always loaded in the same fixed order. Part of its source:
class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))
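Usage is straightforward; the data_source only needs a __len__:

```python
from torch.utils.data import SequentialSampler

# Indices are emitted in order 0..len-1, identically on every pass.
sampler = SequentialSampler(range(5))
print(list(sampler))  # [0, 1, 2, 3, 4]
```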
RandomSampler
As the name suggests, RandomSampler generates a random sequence of indices. Part of its source:
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:  # sampling with replacement
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:  # sampling without replacement
            yield from torch.randperm(n, generator=generator).tolist()
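A usage sketch of the two branches; passing an explicit generator makes the order reproducible:

```python
import torch
from torch.utils.data import RandomSampler

# Without replacement (the default): a random permutation of all indices.
g = torch.Generator()
g.manual_seed(0)
perm = list(RandomSampler(range(5), generator=g))
print(sorted(perm))  # [0, 1, 2, 3, 4]

# With replacement: exactly num_samples draws, duplicates allowed.
g.manual_seed(0)
draws = list(RandomSampler(range(5), replacement=True, num_samples=8, generator=g))
print(len(draws))  # 8
```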
BatchSampler
BatchSampler wraps another sampler and yields a group of indices per iteration.
class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]
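A BatchSampler (or any iterable of index lists) can be handed to DataLoader through the batch_sampler argument; a sketch:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(10))
batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)

# With batch_sampler set, batch_size/shuffle/sampler/drop_last must stay at their defaults.
loader = DataLoader(ds, batch_sampler=batch_sampler)
sizes = [batch[0].shape[0] for batch in loader]
print(sizes)  # [4, 4, 2]
```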
Loading Batched or Non-Batched
If the sampler returns a single index per iteration, the DataLoader processes one sample at a time; if it returns a group of indices, the DataLoader processes one batch at a time. By default, samples are loaded and processed in Batched mode; to process individual samples instead, set both batch_size=None and batch_sampler=None.
collate_fn
collate_fn is the final step in processing Dataset elements. For a map-style dataset, its role is roughly equivalent to:
# Non-batched
for index in sampler:
    yield collate_fn(dataset[index])

# Batched
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
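collate_fn can be replaced with a custom callable. A hypothetical collate_fn that stacks the samples of a batch and also reports the batch size (the function name is invented for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Custom collate_fn: receives the list of samples fetched for one batch.
def collate_with_count(samples):
    xs = torch.stack([x for (x,) in samples])
    return xs, xs.shape[0]

ds = TensorDataset(torch.arange(6, dtype=torch.float32))
loader = DataLoader(ds, batch_size=4, collate_fn=collate_with_count)
xs, n = next(iter(loader))
print(xs.tolist(), n)  # [0.0, 1.0, 2.0, 3.0] 4
```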
Single-Process and Multi-Process Data Loading
- Single-process loading is used by default
- TODO