Contents
Data Loading Order and Sampler
Loading Batched and Non-Batched Data
Single- and Multi-process Data Loading
Single-process data loading (default)
torch.utils.data
At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for map-style and iterable-style datasets, customizing data loading order, automatic batching, single- and multi-process data loading, and automatic memory pinning.
These options are configured by the constructor arguments of a DataLoader, which has signature:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
The sections below describe in detail the effects and usage of these options.
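For instance, a minimal sketch of constructing and iterating a DataLoader (the TensorDataset contents here are purely illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy map-style dataset: 10 samples of 3 features each, plus integer labels.
features = torch.randn(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# Iterate in shuffled minibatches of 4; the final batch has only 2 samples
# because drop_last defaults to False.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for batch_features, batch_labels in loader:
    print(batch_features.shape, batch_labels.shape)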
Dataset Types
The most important argument of the DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:
Map-style datasets
A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.
For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.
See Dataset for more details.
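As a minimal sketch of the protocol (the class and its in-memory list of pairs are hypothetical; a real dataset would typically read files from disk in __getitem__()):

from torch.utils.data import Dataset

class PairListDataset(Dataset):
    """A minimal map-style dataset over an in-memory list of (sample, label) pairs."""
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        # Lets len(dataset) and the default samplers know how many samples exist.
        return len(self.pairs)

    def __getitem__(self, idx):
        # dataset[idx] returns the idx-th sample; a disk-backed dataset
        # would open and decode a file here instead.
        return self.pairs[idx]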
Iterable-style datasets
An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even impossible, and where the batch size depends on the fetched data.
For example, when iter(dataset) is called, such a dataset could return a stream of data read from a database, a remote server, or even logs generated in real time.
See IterableDataset for more details.
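A minimal sketch, assuming the samples are lines of a text file (the class name and file format are illustrative):

from torch.utils.data import IterableDataset

class LineStreamDataset(IterableDataset):
    """Streams samples from a text file, one line at a time."""
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        # Called by iter(dataset); yields samples sequentially, so no
        # random access to the underlying file is ever needed.
        with open(self.path) as f:
            for line in f:
                yield line.rstrip("\n")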
Note
When using an IterableDataset with multi-process data loading, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this.
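One common way to configure the replicas, sketched here with a hypothetical range-over-integers dataset, is to query torch.utils.data.get_worker_info() inside __iter__() and assign each worker a disjoint shard:

import math
from torch.utils.data import IterableDataset, get_worker_info

class RangeDataset(IterableDataset):
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            # Single-process loading: this replica yields the full range.
            iter_start, iter_end = self.start, self.end
        else:
            # Multi-process loading: give each worker a disjoint slice.
            per_worker = int(math.ceil((self.end - self.start) / worker_info.num_workers))
            iter_start = self.start + worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))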
Data Loading Order and Sampler
For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample at each time).
The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.
A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.
A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via the batch_size and drop_last arguments. See the next section for more details on this.
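For illustration, a sketch of a custom Sampler passed to a DataLoader (the even-indices policy is purely illustrative):

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class EvenIndexSampler(Sampler):
    """Yields only the even indices of the dataset, in order."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, sampler=EvenIndexSampler(dataset), batch_size=2)
# Automatic batching draws keys 0, 2, 4, 6, 8 from the sampler, two at a time.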
Note
Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index.
Loading Batched and Non-Batched Data
DataLoader supports automatically collating individual fetched data samples into batches via the arguments batch_size, drop_last, and batch_sampler.
Automatic batching (default)
This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first).
When batch_size (default 1) is not None, the data loader yields batched samples instead of individual samples. The batch_size and drop_last arguments are used to specify how the data loader obtains batches of dataset keys. For map-style datasets, users can alternatively specify batch_sampler, which yields a list of keys at a time.
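To make the interaction of batch_size and drop_last concrete, a small sketch (the dataset of ten scalars is illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
# 10 samples with batch_size=3 gives batches of sizes 3, 3, 3, 1;
# drop_last=True discards the final, incomplete batch of size 1.
loader = DataLoader(dataset, batch_size=3, drop_last=True)
print([batch[0].tolist() for batch in loader])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]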
Note
The batch_size and drop_last arguments are essentially used to construct a batch_sampler from the sampler. For map-style datasets, the sampler is either provided by the user or constructed based on the shuffle argument. For iterable-style datasets, the sampler is a dummy infinite one. See this section for more details on samplers.
Note
When fetching from iterable-style datasets with multi-processing, the drop_last argument drops the last non-full batch of each worker’s dataset replica.
After fetching a list of samples using the indices from the sampler, the function passed as the collate_fn argument is used to collate the lists of samples into batches.
In this case, loading from a map-style dataset is roughly equivalent to:
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
and loading from an iterable-style dataset is roughly equivalent to:
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
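For example, a custom collate_fn can replace the default collation; this sketch (the zero-padding scheme is illustrative) pads variable-length 1-D tensors so they can be stacked into one batch tensor:

import torch
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch is a list of 1-D tensors of varying lengths; pad each with
    # zeros on the right so they can be stacked into one 2-D tensor.
    max_len = max(t.size(0) for t in batch)
    padded = [torch.nn.functional.pad(t, (0, max_len - t.size(0))) for t in batch]
    return torch.stack(padded)

sequences = [torch.ones(n) for n in (2, 5, 3)]  # a toy map-style dataset
loader = DataLoader(sequences, batch_size=3, collate_fn=pad_collate)
print(next(iter(loader)).shape)  # torch.Size([3, 5])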