PyTorch学习 数据加载(Dataset、DataLoader)模块介绍及源码分析

Dataset

Dataset类在torch.util.data里定义,所以引用方式为from torch.util.data import Dataset

Dataset类定义的操作需要完成:对单个样本完成读取,以及某些可能进行的预处理

对于Dataset类,我们需要完成三个方法:__init__,__getitem__,__len__

方法名作用
__init__(self, *loader_args, **loader_kwargs)完成Dataset类的初始化
__getitem__(self, index)基于索引返回某个样本(sample, label)
__len__(self)返回所有样本个数

以covid19数据集加载举例(LiHongYee,MLSpring2022HW1)

class COVID19Dataset(Dataset):
    def __init__(self,
                 covid_features,
                 covid_labels,
                 select_features=None,
                 select_features_model=None):
        self.covid_features = np.array(covid_features)
        self.covid_labels = covid_labels
        self.select_features = select_features
        self.select_features_model = select_features_model

        if select_features is not None and select_features_model is not None:
            self.covid_features = self.select_features_model.transform(self.covid_features)

        self.covid_features = torch.from_numpy(self.covid_features).float()
        if self.covid_labels is not None:
            self.covid_labels = torch.from_numpy(np.array(self.covid_labels)).float()

        self.input_dim = self.covid_features.shape[1]

    def __getitem__(self, index):
        if self.covid_labels is None:
            return self.covid_features[index]
        else:
            return self.covid_features[index], self.covid_labels[index]

    def __len__(self):
        return len(self.covid_features)

DataLoader

DataLoader在torch.util.data里定义,所以引用方式为from torch.util.data import DataLoader

DataLoader类定义的操作需要完成:将Dataset里的单个样本处理成mini batch

对于DataLoader类,如果要自定义,则一般需要完成__init__和__len__方法。如果无需更多配置,则将自定义的Dataset类传入DataLoader即可

DataLoader参数

dataset (Dataset): dataset from which to load the data. 
自定义的Dataset

batch_size (int, optional): how many samples per batch to load (default: ``1``). 
mini batch的大小,通常把batch_size改大一点,为2的整数次幂

shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: ``False``). 
在每轮训练后,将数据集打乱

sampler (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__`` implemented. If specified, :attr:`shuffle` must not be specified.
自定义方法(某种顺序)从Dataset中取样本,指定这个参数就不能设置shuffle
指定shuffle相当于使用内置的RandomSampler进行采样,否则使用SequentialSampler
RandomSampler的__iter__方法有一行代码:yield from torch.randperm(n, generator=self.generator).tolist()
SequentialSampler: return iter(range(len(self.data_source))),均继承了Sampler[int]

batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. 
返回一个batch的索引,与batch_size, shuffle, sampler, drop_last互斥
传入了batch_sampler,相当于已经告诉了PyTorch如何从Dataset取多少数据,怎么取数据去组成一个mini batch,所以不需要以上参数

num_workers (int, optional): how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process. (default: ``0``) 
多进程加载数据,默认为0,即采用主进程加载数据

collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s).  Used when using batched loading from a map-style dataset. 
聚集函数,用来对一个batch进行后处理,拿到一个batch的数据后进行什么处理,用这个参数定义,返回处理后的batch数据
常用默认:_utils.collate.default_collate,源码中进行了若干逻辑判断,仅将数据组合起来返回,没有实质性工作
默认collate_fn的声明是:def default_collate(batch): 所以自定义collate_fn需要以batch为输入,以处理后的batch为输出

pin_memory (bool, optional): If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them.  If your data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, see the example below.
用于将tensor加载到GPU中进行运算

drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: ``False``)
是否保存最后一个mini batch,样本数量可能不支持被batch size整除,所以drop_last参数决定是否保留最后一个可能批量较小的batch

timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: ``0``)
控制从进程中获取一个batch数据的时延

worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``)
初始化子进程

prefetch_factor (int, optional, keyword-only arg): Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers samples prefetched across all workers. (default: ``2``)
控制样本在每个进程里的预加载,默认为2

persistent_workers (bool, optional): If ``True``, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. (default: ``False``)
控制加载完一次Dataset是否保留进程,默认为False

DataLoader源码剖析

在DataLoader的__init__函数里,我们可以看到,它实现了:

  1. 构建Sampler,单样本
  2. 构建BatchSampler,组建batch
  3. 构建collate

默认参数逻辑:

if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset, generator=generator)
        else:
            sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
self.generator = generator

if collate_fn is None:
    if self._auto_collation:
        collate_fn = _utils.collate.default_collate
    else:
        collate_fn = _utils.collate.default_convert

self._dataset_kind == _DatasetKind.Iterable是在Dataset类是IterableDataset时才为True

if isinstance(dataset, IterableDataset):
	self._dataset_kind = _DatasetKind.Iterable

IterableDataset应用于数据集非常大,将其完全加载进内存不现实(例如高达几个TB的数据),这时就需要IterableDataset构建可迭代的Dataset类,自定义的Dataset需要继承自torch.util.data.IterableDataset,重写__iter__方法,返回可迭代对象(通常是yield生成器)

所以,对于IterableDataset来说,就没有构建采样器Sampler的需求,因为样本是通过调用__iter__一个个读取出来的。执行封装的DataLoader传进去的batch_size次__iter__方法,就获取到一个mini batch

IterableDataset对应的_InfiniteConstantSampler为:

class _InfiniteConstantSampler(Sampler):
    r"""Analogous to ``itertools.repeat(None, None)``.
    Used as sampler for :class:`~torch.utils.data.IterableDataset`.

    Args:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self):
        super(_InfiniteConstantSampler, self).__init__(None)

    def __iter__(self):
        while True:
            yield None

可以看到,__iter__方法返回None的生成器

所以,对于自定义的Dataset,如果shuffle为True,调用RandomSampler,否则为SequentialSampler

RandomSampler源码剖析

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            yield from torch.randperm(n, generator=generator).tolist()

    def __len__(self) -> int:
        return self.num_samples

我们主要关注__iter__方法,可以看到:

  1. n为数据集大小

  2. 如果指定了replacement参数为True,则需要指定num_samples参数,表示采样器需要返回的样本个数

    PyTorch源码里通过torch.randint以32为一批返回0~n-1的随机整数,每一批共计32个采样下标,共采样num_samples // 32批,最后一批的采样下标数为num_samples对32取余,所以最后的采样下标数总和为num_samples

  3. 如果保持默认的replacement参数为False,则通过torch.randperm(n)返回0~n-1的随机序列,共计n个采样下标

SequentialSampler源码剖析

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

SequentialSampler的__iter__方法返回顺序迭代器,每次调用__iter__方法即可返回顺序下标

BatchSampler源码剖析

class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]

BatchSampler需要传入一个其他的Sampler,用以将该Sampler生成的采样下标组装成mini batch的采样下标

我们重点关注BatchSampler的__iter__方法,可以看到:通过for循环调用sampler的iter方法,拿到一个采样下标放入batch列表里,直到batch列表的长度等于指定的batch size,返回batch对应的生成器,随后重置batch列表为空,再接着从sampler里继续取采样下标。如果drop_last为False并且最后一个batch有样本的话,就把最后一个不满batch size的采样下标生成器返回

__len__方法返回总共的batch数,即所有的样本被分成了多少个batch

default_collate源码剖析

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

default_collate大部分都在做合理性判断的工作,实质上是把所有相关的数据转换成tensor,把Dataset的__getitem__的对应数据组装后返回。例如:[(img0, label0), (img1, label1),(img2, label2), ] 整理成[[img0,img1,img2,], [label0,label1,label2,]],这里要求多个img的size相同(根据isinstance(elem, collections.abc.Sequence可以看出这就是为什么遍历DataLoader时,我们拿到的是列表数据)

collate_fn是对一个batch的数据做后处理,即结合BatchSampler给的mini batch采样下标,利用Dataset里的__getitem__(self, index)方法,取出一个batch的数据,然后传到collate_fn里进行处理。为了摸清collate_fn的运行机制,我们先去DataLoader源码的__iter__方法里看它是怎么取数据的

# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> '_BaseDataLoaderIter':
    # When using a single worker the returned iterator should be
    # created everytime to avoid reseting its state
    # However, in the case of a multiple workers iterator
    # the iterator is only created once in the lifetime of the
    # DataLoader object so that workers can be reused
    if self.persistent_workers and self.num_workers > 0:
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()

可以看到,这里DataLoader在__iter__里调用_get_iterator方法创建迭代器,所以我们再去阅读_get_iterator方法

def _get_iterator(self) -> '_BaseDataLoaderIter':
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)

我们先关注默认情况,即num_workers=0时,_get_iterator方法返回了一个_SingleProcessDataLoaderIter实例,而这个_SingleProcessDataLoaderIter实例继承自_BaseDataLoaderIter这个基类,_BaseDataLoaderIter类里实现了__iter__方法和__next__方法,用于对这个迭代器遍历取数据

可迭代对象实现了__iter__方法,支持重复遍历,但不支持next(可迭代对象),而迭代器不支持重复遍历,采用iter(可迭代对象)获取对应的迭代器,这时可以对其使用next方法。迭代器如果实现了__next__方法,就可以使用next(迭代器)返回迭代器的下一个值

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

在_SingleProcessDataLoaderIter里,我们可以看到:这个迭代器又实例化了一个数据获取器_dataset_fetcher,而这个数据获取器接收了collate_fn参数

_SingleProcessDataLoaderIter实现了_next_data方法,它先调用_next_index方法获取下一批采样下标,而这批采样下标就是从_BaseDataLoaderIter基类里的_next_index方法获得的,该方法调用了之前BatchSampler的迭代器(auto_collation为True的默认情况,因为DataLoader设置batch_size不为None时会创建BatchSampler,然后将_index_sampler设置为BatchSampler)获取下一批次采样下标

可以看到,在基类的__next__方法调用了_next_data方法获取下一批次数据

_SingleProcessDataLoaderIter的_next_data方法调用的是数据获取器_dataset_fetcher的fetch方法

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

而这里的fetch方法,如果auto_collation为True(设置了batch_size,自动创建了BatchSampler),就根据下一批的采样下标,从dataset里根据__getitem__组装数据,返回组装后的列表;否则,就依据Sampler(auto_collation为False时,前面的_index_sampler就为Sampler)的迭代器给出的单个采样下标,取dataset的一条数据

至此,我们终于见到collate_fn在此处被调用,这也明确了collate_fn确实起到了取出批次数据之后的处理作用

collate_fn输入数据在auto_collation为True时是一个列表,列表里的每个元素是Dataset的__getitem__返回的值,在auto_collation为False时,是Dataset的__getitem__返回的单条样例的数据类型

参考资源:

https://blog.csdn.net/weixin_35757704/article/details/119715900

https://www.daimajiaoliu.com/daima/4ede05ecd1003fc

https://blog.csdn.net/mieleizhi0522/article/details/82142856/

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值