PyTorch DataLoader: map-style datasets

import numpy as np
import torch
import torch.utils.data as data

data_tensor=torch.Tensor(range(1,17)).reshape(4,1,2,2)
label_tensor=torch.Tensor([0,1,1,0])
dataset=data.TensorDataset(data_tensor,label_tensor)

for da,label in dataset:
    print(da,label)


out: 
tensor([[[1., 2.],
         [3., 4.]]]) tensor(0.)
tensor([[[5., 6.],
         [7., 8.]]]) tensor(1.)
tensor([[[ 9., 10.],
         [11., 12.]]]) tensor(1.)
tensor([[[13., 14.],
         [15., 16.]]]) tensor(0.)

len(dataset) #4
# Subset is a shallow view: it stores a reference to the original dataset plus the chosen indices
# (subset_dataset.dataset is dataset), so no data is copied.
subset_dataset=data.Subset(dataset,[0,2])
subset_dataset[0]   # == dataset[0]
subset_dataset[1]   # == dataset[2]


out:
(tensor([[[1., 2.],
         [3., 4.]]]), tensor(0.))
(tensor([[[ 9., 10.],
         [11., 12.]]]), tensor(1.))

subset_dataset2=data.Subset(dataset,[1,3])
dataset_=subset_dataset+subset_dataset2 # Dataset.__add__ calls ConcatDataset([self, other])
for da,label in dataset_:
    print(da,label)


out:
tensor([[[1., 2.],
         [3., 4.]]]) tensor(0.)
tensor([[[ 9., 10.],
         [11., 12.]]]) tensor(1.)
tensor([[[5., 6.],
         [7., 8.]]]) tensor(1.)
tensor([[[13., 14.],
         [15., 16.]]]) tensor(0.)

random_split_datasets=data.random_split(dataset,[1,2,1])
for i in random_split_datasets:
    for da,label in i:
        print(da,label)
    print('---'*5)

out:
tensor([[[ 9., 10.],
         [11., 12.]]]) tensor(1.)
---------------
tensor([[[5., 6.],
         [7., 8.]]]) tensor(1.)
tensor([[[1., 2.],
         [3., 4.]]]) tensor(0.)
---------------
tensor([[[13., 14.],
         [15., 16.]]]) tensor(0.)
---------------
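
Each split returned by random_split is a data.Subset over the same underlying dataset (see the source
further down), so the permuted indices it selected can be inspected directly. A quick check (the exact
values depend on the random permutation):

for s in random_split_datasets:
    print(len(s), s.indices)   # e.g. 1 [2] / 2 [1, 0] / 1 [3]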




'''
Disable automatic batching
In certain cases, users may want to handle batching manually in dataset code, 
or simply load individual samples. For example, it could be cheaper to directly load batched data
 (e.g., bulk reads from a database or reading continuous chunks of memory), or the batch size is 
 data dependent, or the program is designed to work on individual samples. Under these scenarios,
 it’s likely better to not use automatic batching (where collate_fn is used to collate the samples),
 but let the data loader directly return each member of the dataset object.

When both batch_size and batch_sampler are None (default value for batch_sampler is already None),
 automatic batching is disabled. Each sample obtained from the dataset is processed with the 
 function passed as the collate_fn argument.

When automatic batching is disabled, the default collate_fn simply converts NumPy arrays 
into PyTorch Tensors, and keeps everything else untouched.

In this case, loading from a map-style dataset is roughly equivalent with:

for index in sampler:
    yield collate_fn(dataset[index])
'''
# batch_size=None and batch_sampler=None: automatic batching is disabled;
# default sampler -> sampler = SequentialSampler(dataset)
dataloader1=data.DataLoader(dataset,batch_size=None)
# batch_size=None and batch_sampler=None: automatic batching is disabled;
# sampler -> RandomSampler
dataloader2=data.DataLoader(dataset,batch_size=None,\
                            sampler=data.RandomSampler(dataset,replacement=False, num_samples=None))
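
To see the index streams these loaders will consume, the samplers can be iterated directly; this is
just an illustration, and the RandomSampler order changes on every pass:

list(data.SequentialSampler(dataset))   # [0, 1, 2, 3]
list(data.RandomSampler(dataset))       # a permutation of [0, 1, 2, 3], e.g. [2, 0, 3, 1]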

# batch_size defaults to 1; when batch_size is given, drop_last and shuffle can also be used. shuffle=True -> sampler = RandomSampler(dataset)
dataloader3=data.DataLoader(dataset,batch_size=3,shuffle=True)
# sampler = SequentialSampler(dataset); drop_last discards the trailing samples that do not fill a complete batch
dataloader4=data.DataLoader(dataset,batch_size=3,shuffle=False,drop_last=True)
#sampler = SequentialSampler(dataset)
dataloader5=data.DataLoader(dataset,batch_size=3,shuffle=False,drop_last=False)
'''sampler is mutually exclusive with shuffle; batch_sampler is mutually exclusive with batch_size,
shuffle, sampler and drop_last. If batch_sampler is None, DataLoader internally builds a BatchSampler
from the remaining arguments.
'''
batch_sampler=data.BatchSampler(sampler=data.RandomSampler(dataset),\
                                batch_size=2, drop_last=False)
dataloader6=data.DataLoader(dataset,batch_sampler=batch_sampler)
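
For reference, iterating a BatchSampler directly shows the batches of indices it hands to the
DataLoader; with a SequentialSampler the result is deterministic (the batch_sampler above uses
RandomSampler, so its batches are a random partition of [0, 1, 2, 3]):

list(data.BatchSampler(sampler=data.SequentialSampler(dataset), batch_size=2, drop_last=False))
# out: [[0, 1], [2, 3]]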

dataloader=[]
dataloader.append(dataloader1)
dataloader.append(dataloader2)
dataloader.append(dataloader3)
dataloader.append(dataloader4)
dataloader.append(dataloader5)
dataloader.append(dataloader6)
for dataloader_ in dataloader:
    for d in dataloader_:
        print(d)
        print('***'*20)
    print('---'*20)

'''
Using the indices produced by the sampler, the DataLoader calls fetch to pull the samples out of the
dataset and then applies collate_fn, turning d: [(img1,lab1),(img2,lab2),(img3,lab3)] into
d: [(img1,img2,img3),(lab1,lab2,lab3)], i.e. each field is stacked into a batched tensor.
'''

out:
[tensor([[[1., 2.],
         [3., 4.]]]), tensor(0.)]
************************************************************
[tensor([[[5., 6.],
         [7., 8.]]]), tensor(1.)]
************************************************************
[tensor([[[ 9., 10.],
         [11., 12.]]]), tensor(1.)]
************************************************************
[tensor([[[13., 14.],
         [15., 16.]]]), tensor(0.)]
************************************************************
------------------------------------------------------------
[tensor([[[1., 2.],
         [3., 4.]]]), tensor(0.)]
************************************************************
[tensor([[[5., 6.],
         [7., 8.]]]), tensor(1.)]
************************************************************
[tensor([[[13., 14.],
         [15., 16.]]]), tensor(0.)]
************************************************************
[tensor([[[ 9., 10.],
         [11., 12.]]]), tensor(1.)]
************************************************************
------------------------------------------------------------
[tensor([[[[ 1.,  2.],
          [ 3.,  4.]]],


        [[[ 5.,  6.],
          [ 7.,  8.]]],


        [[[13., 14.],
          [15., 16.]]]]), tensor([0., 1., 0.])]
************************************************************
[tensor([[[[ 9., 10.],
          [11., 12.]]]]), tensor([1.])]
************************************************************
------------------------------------------------------------
[tensor([[[[ 1.,  2.],
          [ 3.,  4.]]],


        [[[ 5.,  6.],
          [ 7.,  8.]]],


        [[[ 9., 10.],
          [11., 12.]]]]), tensor([0., 1., 1.])]
************************************************************
------------------------------------------------------------
[tensor([[[[ 1.,  2.],
          [ 3.,  4.]]],


        [[[ 5.,  6.],
          [ 7.,  8.]]],


        [[[ 9., 10.],
          [11., 12.]]]]), tensor([0., 1., 1.])]
************************************************************
[tensor([[[[13., 14.],
          [15., 16.]]]]), tensor([0.])]
************************************************************
------------------------------------------------------------
[tensor([[[[ 5.,  6.],
          [ 7.,  8.]]],


        [[[13., 14.],
          [15., 16.]]]]), tensor([1., 0.])]
************************************************************
[tensor([[[[ 1.,  2.],
          [ 3.,  4.]]],


        [[[ 9., 10.],
          [11., 12.]]]]), tensor([0., 1.])]
************************************************************
------------------------------------------------------------
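
The collation step above can also be reproduced by hand. default_collate is re-exported from
dataloader.py (see below), so the following sketch calls it directly on a list of (data, label)
samples taken from the demo dataset:

from torch.utils.data.dataloader import default_collate
batch = [dataset[1], dataset[2], dataset[0]]   # [(img, lab), (img, lab), (img, lab)]
images, labels = default_collate(batch)        # -> [stacked images, stacked labels]
print(images.shape, labels)                    # torch.Size([3, 1, 2, 2]) tensor([1., 1., 0.])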

dataset.py

import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm


class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
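
As a quick illustration of the contract described in this docstring (this class is not part of
dataset.py), a minimal map-style dataset only needs __getitem__ and __len__:

class SquaresDataset(Dataset):
    """Toy map-style dataset: item i is the pair (i, i ** 2)."""
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        return index, index ** 2

    def __len__(self):
        return self.n

# SquaresDataset(4)[2] -> (2, 4); instances work with DataLoader and the default samplers.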


class IterableDataset(Dataset):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]

        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [3, 4, 5, 6]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __iter__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ChainDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    # *tensors is a variadic parameter, so the passed tensors are packed into a tuple; that is why
    # __len__ below returns self.tensors[0].size(0).
    def __init__(self, *tensors):  
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)  # one slice per wrapped tensor, e.g. (data[index], label[index])

    def __len__(self):
        return self.tensors[0].size(0)


class ConcatDataset(Dataset):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
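
A worked example of the index arithmetic in __getitem__, assuming two concatenated datasets of
lengths 4 and 2 so that cumulative_sizes == [4, 6]:

# idx = 3 -> bisect.bisect_right([4, 6], 3) == 0 -> dataset_idx = 0, sample_idx = 3
# idx = 5 -> bisect.bisect_right([4, 6], 5) == 1 -> dataset_idx = 1,
#            sample_idx = 5 - cumulative_sizes[0] = 1, i.e. the second sample of the second dataset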


class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Arguments:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets):
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(d)
        return total


class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]  # double indexing: subset[idx] == dataset[subset.indices[idx]]

    def __len__(self):
        return len(self.indices)


def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
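
For the lengths [1, 2, 1] used in the demo above, _accumulate yields the offsets [1, 3, 4], so the
shuffled index list is sliced into indices[0:1], indices[1:3] and indices[3:4], each wrapped in a
Subset. A sketch with one possible permutation:

# randperm(4).tolist() might give [2, 1, 0, 3]
# -> Subset(dataset, [2]), Subset(dataset, [1, 0]), Subset(dataset, [3])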

dataloader.py

import torch
import multiprocessing as python_multiprocessing
import torch.multiprocessing as multiprocessing
from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler
from . import _utils
from torch._utils import ExceptionWrapper
import threading
import itertools
from torch._six import queue, string_classes


get_worker_info = _utils.worker.get_worker_info

# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate


class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)


class _InfiniteConstantSampler(Sampler):
    r"""Analogous to ``itertools.repeat(None, None)``.
    Used as sampler for :class:`~torch.utils.data.IterableDataset`.
    """

    def __init__(self):
        super(_InfiniteConstantSampler, self).__init__(None)

    def __iter__(self):
        while True:
            yield None

    def __len__(self):
        # This has to be a TypeError, otherwise, since this is used in
        # `len(dataloader)`, `list(dataloader)` will fail.
        # see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        raise TypeError('Cannot determine the DataLoader length of a IterableDataset')


class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, :attr:`shuffle` must be ``False``.
        batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)


    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
              When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
              an infinite sampler is used, whose :meth:`__len__` is not
              implemented, because the actual length depends on both the
              iterable as well as multi-process loading configurations. So one
              should not query this method unless they work with a map-style
              dataset. See `Dataset Types`_ for more details on these two types
              of datasets.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        self.dataset = dataset
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and `IterableDataset` ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
                    
                    
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
            
        elif batch_size is None:
            # no auto_collation
            if shuffle or drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with '
                                 'shuffle, and drop_last')
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
        #The batch_size and drop_last arguments essentially are used to construct a batch_sampler from sampler.
        #For map-style datasets, the sampler is either provided by user or constructed based on the shuffle argument.
        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler  # if the user passed a sampler, the default-sampler block above was skipped;
                                # otherwise shuffle decides between RandomSampler and SequentialSampler
        self.batch_sampler = batch_sampler  # the sampler that actually drives index generation when auto-batching is on

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.__initialized = True

    @property
    def multiprocessing_context(self):
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if not multiprocessing._supports_context:
                    raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
                                     'support for different start methods')

                if isinstance(multiprocessing_context, string_classes):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            ('multiprocessing_context option '
                             'should specify a valid start method in {}, but got '
                             'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise ValueError(('multiprocessing_context option should be a valid context '
                                      'object or a string specifying the start method, but got '
                                      'multiprocessing_context={}').format(multiprocessing_context))
            else:
                raise ValueError(('multiprocessing_context can only be used with '
                                  'multi-process loading (num_workers > 0), but got '
                                  'num_workers={}').format(self.num_workers))

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler  # if batch_size is None, automatic batching is disabled and the plain sampler is used

    def __len__(self):
        return len(self._index_sampler)  # with iterable-style dataset, this will error
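
Relating these properties back to the loaders built in the demo above (private attributes, inspected
here purely for illustration):

dataloader1._auto_collation, len(dataloader1)   # (False, 4): batch_size=None, _index_sampler is the SequentialSampler
dataloader3._auto_collation, len(dataloader3)   # (True, 2):  batch_size=3, _index_sampler is the BatchSampler, ceil(4/3) == 2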


class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)  # iterator over batch_sampler or sampler, depending on _auto_collation
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()

    def __iter__(self):
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def __next__(self):
        raise NotImplementedError

    def __len__(self):
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def __next__(self):
        index = self._next_index()  # may raise StopIteration; from a batch_sampler e.g. [1, 2, 0], from a plain sampler e.g. 3
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration; e.g. collate_fn([dataset[1], dataset[2], dataset[0]])
        if self._pin_memory:  # continuing the line above: collate_fn([(img1,lab1),(img2,lab2),(img3,lab3)]) -> [(img1,img2,img3),(lab1,lab2,lab3)]
            data = _utils.pin_memory.pin_memory(data)
        return data
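
The same flow can be observed from user code by stepping the iterator by hand. With num_workers=0,
iter() on a DataLoader returns a _SingleProcessDataLoaderIter, so for the demo loader:

it = iter(dataloader3)     # dataloader3 has batch_size=3, shuffle=True, num_workers=0
batch = next(it)           # internally: index = next(sampler_iter); data = fetcher.fetch(index)
print(batch[0].shape)      # torch.Size([3, 1, 2, 2]) for the first, full batch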

fetch.py

r""""Contains definitions of the methods used by the _DataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""


class _BaseDatasetFetcher(object):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index):
        raise NotImplementedError()


class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)


class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:  #batch_sampler is not None
            data = [self.dataset[idx] for idx in possibly_batched_index] #possibly_batched_index->e.g. [1,2,0]
        else:
            data = self.dataset[possibly_batched_index]  #possibly_batched_index   e.g. 3
        return self.collate_fn(data)  
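
The two branches match the two loader configurations shown earlier. A sketch that drives the fetcher
directly (an internal class, used here only to illustrate the call pattern; it assumes the demo
dataset from the top and the collate helpers from collate.py below are in scope):

fetcher = _MapDatasetFetcher(dataset, auto_collation=True, collate_fn=default_collate, drop_last=False)
fetcher.fetch([1, 2, 0])   # -> [images stacked to shape (3, 1, 2, 2), tensor([1., 1., 0.])]

fetcher = _MapDatasetFetcher(dataset, auto_collation=False, collate_fn=default_convert, drop_last=False)
fetcher.fetch(3)           # -> [tensor([[[13., 14.], [15., 16.]]]), tensor(0.)]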

collate.py

r""""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import re
from torch._six import container_abcs, string_classes, int_classes

np_str_obj_array_pattern = re.compile(r'[SaUO]')


def default_convert(data):
    r"""Converts each NumPy array data field into a tensor"""
    #default_convert((np.array([[1, 2],[3, 4]]),0))->[tensor([[1, 2],[3, 4]], dtype=torch.int32), 0]
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array of string classes and object
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, container_abcs.Mapping):
        return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return [default_convert(d) for d in data]  # this is the branch taken for a (data, label) tuple
    else:
        return data


default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    # e.g. [('w', 12), ('e', 13), ('l', 14)] -> [('w', 'e', 'l'), tensor([12, 13, 14])]; 'w' could also be a NumPy array

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
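
Besides lists of tuples, default_collate recurses into mappings and namedtuples and collates field by
field; a small sketch:

samples = [{'x': torch.zeros(2), 'y': 0},
           {'x': torch.ones(2),  'y': 1}]
default_collate(samples)
# out: {'x': tensor([[0., 0.], [1., 1.]]), 'y': tensor([0, 1])}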

 
