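Introduction
dataset: the dataset, which provides the data
dataloader: the data loader, which loads the data and feeds it into the neural network
When data is taken from a dataset, the parameters set on the dataloader determine how it is taken (batch size, ordering, and so on)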
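To make the split between the two roles concrete, here is a minimal sketch. `SquaresDataset` is a hypothetical toy dataset invented for this illustration. It only provides items, while the `DataLoader` decides how they are grouped into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: item i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor(float(idx))
        return x, x * x


dataset = SquaresDataset(10)
# The loader, not the dataset, controls batching; shuffle defaults to False.
loader = DataLoader(dataset, batch_size=4)

batch_shapes = []
for xs, ys in loader:
    batch_shapes.append(tuple(xs.shape))
    print(xs, ys)
```

With 10 items and batch_size=4, the loader yields two batches of 4 and a final batch of 2.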
Usage
Import it first: from torch.utils.data import DataLoader
Parameters: most parameters have default values
Args:
dataset (Dataset): the dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load (default: ``1``). Note: the number of samples loaded at each step.
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: ``False``). Note: whether the data is reshuffled for each epoch (defaults to False; usually set to True for training).
sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
and :attr:`drop_last`.
num_workers (int, optional): how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``) Note: a value greater than 0 loads data with multiple processes (this may cause errors on Windows).
collate_fn (Callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
into device/CUDA pinned memory before returning them. If your data elements
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the example below.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: ``False``) Note: whether to discard the leftover samples when the dataset size is not divisible by the batch size.
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: ``0``)
worker_init_fn (Callable, optional): If not ``None``, this will be called on each
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: ``None``)
multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
``None``, the default `multiprocessing context`_ of your operating system will
be used. (default: ``None``)
generator (torch.Generator, optional): If not ``None``, this RNG will be used
by RandomSampler to generate random indexes and multiprocessing to generate
``base_seed`` for workers. (default: ``None``)
prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
in advance by each worker. ``2`` means there will be a total of
2 * num_workers batches prefetched across all workers. (default value depends
on the set value for num_workers. If value of num_workers=0 default is ``None``.
Otherwise, if value of ``num_workers > 0`` default is ``2``).
persistent_workers (bool, optional): If ``True``, the data loader will not shut down
the worker processes after a dataset has been consumed once. This allows to
maintain the workers `Dataset` instances alive. (default: ``False``)
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
``True``.
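Of the parameters above, collate_fn is probably the least obvious, so here is a small sketch of it. The sequences and the `pad_collate` helper are hypothetical, made up for this illustration. The helper zero-pads variable-length samples so they can be stacked into one batch tensor, which the default collation cannot do for tensors of different lengths.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy data: variable-length sequences. A plain list works as a map-style
# dataset because it implements __getitem__ and __len__.
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4]),
    torch.tensor([5, 6]),
    torch.tensor([7, 8, 9, 10]),
]


def pad_collate(batch):
    """Hypothetical collate_fn: zero-pad each sequence to the longest one in the batch."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths


loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
batches = list(loader)
for padded, lengths in batches:
    print(padded.shape, lengths.tolist())
```

Returning the original lengths alongside the padded tensor is one common design choice, since it lets later code ignore the padding.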
Practice
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Prepare the test dataset (assumes CIFAR10 is already present in ./dataset; otherwise pass download=True)
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=torchvision.transforms.ToTensor())
# Load the test set in batches
test_loader = DataLoader(test_data, shuffle=True, batch_size=64, num_workers=0, drop_last=False)

# Inspect the first sample: the image tensor's shape and its class index
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter('logs')
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        writer.add_images("epoch: {}".format(epoch), imgs, step)
        step = step + 1
# close must be called with parentheses; a bare writer.close does nothing
writer.close()
With shuffle=True, the order of the data differs from one epoch to the next
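The effect of shuffle can be checked without downloading CIFAR10. This is a minimal sketch on a toy TensorDataset holding the numbers 0 to 9. The `epoch_order` helper and the seeded generator are assumptions added for this illustration, the seed only making the run repeatable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10))


def epoch_order(loader):
    """Flatten one full pass over the loader into a list of sample values."""
    return [int(v) for (batch,) in loader for v in batch]


ordered = DataLoader(data, batch_size=4, shuffle=False)
shuffled = DataLoader(
    data, batch_size=4, shuffle=True,
    generator=torch.Generator().manual_seed(0),  # assumption: seeded for repeatability
)

# Iterating a loader twice simulates two epochs.
ordered_epochs = [epoch_order(ordered) for _ in range(2)]
shuffled_epochs = [epoch_order(shuffled) for _ in range(2)]
print(ordered_epochs)
print(shuffled_epochs)
```

The unshuffled loader returns 0 to 9 in order on every pass. The shuffled loader returns a permutation of the same values, drawn anew each epoch, so the two passes will almost always differ.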
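With drop_last=False, the leftover samples that do not fill a whole batch are kept, so the last step holds fewer samples than the configured batch size (for the 10,000 CIFAR10 test images and batch_size=64, the final batch has 16 images)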
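The same behaviour can be seen on a small scale. This sketch uses a toy TensorDataset of 10 samples with batch_size=4. The `batch_sizes` helper is hypothetical, written only for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10))


def batch_sizes(drop_last):
    """Return the number of samples in each batch for one pass over the data."""
    loader = DataLoader(data, batch_size=4, drop_last=drop_last)
    return [len(batch) for (batch,) in loader]


kept = batch_sizes(drop_last=False)     # the incomplete last batch is kept
dropped = batch_sizes(drop_last=True)   # the incomplete last batch is discarded
print(kept)
print(dropped)
```

With drop_last=False the sizes are 4, 4 and 2. With drop_last=True the final two samples are skipped and only the two full batches remain.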