DataLoader简介
PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader 类。它表示数据集的 Python 可迭代对象, 支持
1.map样式和iterable样式的数据集;
map风格的数据集是实现__getitem__()和__len__()协议的数据集, 并表示从索引/键到数据样本的映射。
2.自定义数据加载顺序;
3.自动划分batch;
4.单进程和多进程数据加载;
5.自动内存固定。
DataLoader 参数作用
Dataset 一次检索一个样本和标签。在训练模型的时候, 我们需要一个batch的样本, 并且每个epoch样本随机打乱(降低过拟合可能性), 多线程加载数据。DataLoader 的工作流程可以参考PyTorch DataLoader工作原理可视化
DataLoader(dataset, # Dataset类的实例
batch_size=1, # 一个batch图片的个数
shuffle=False, # 是否训练完一个epoch, 数据重新打乱
sampler=None, # 自定义采样方法
batch_sampler=None, # 把sampler的采样的样本根据batch_size组织成一个batch返回
num_workers=0, # 加载数据进程的个数
collate_fn=None, # 自定方法对每个batch的数据进一步进行处理
pin_memory=False, # 是否将数据存储到显存
drop_last=False, # 最后一个batch数据个数不够是否丢弃
timeout=0,
worker_init_fn=None,
*,
prefetch_factor=2,
persistent_workers=False)
Sampler 参数
torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的顺序。它们表示数据集索引上的可迭代对象。例如, 在随机梯度渐变(SGD)的常见情况下, Sampler 可以随机排列索引列表并一次生成一个索引或者一次生成一个 mini-batch 的索引。
可以根据 DataLoader 的 shuffle 参数自动构造顺序的或乱序的采样器。或者, 用户可以使用 sampler 参数 来指定一个自定义sampler对象, 该对象每次都会产生下一个要获取的索引/键。每次产生一个 batch 索引列表的自定义 Sampler 可以作为 batch_sampler 参数。更加详细的讲解可以参考Sampler类与4种采样方式
Sequential Sampler(顺序采样)
Random Sampler(随机采样)
Subset Random Sampler(子集随机采样)
Weighted Random Sampler(加权随机采样)
批处理
1.自动批处理
当 batch_size(默认值1)不为 None 时, 数据加载器生成成批样本而不是单个样本。batch_size 和drop_last参数 用于指定数据加载器如何获取批量的数据集键。对于 Map样式 的数据集, 用户可以指定 batch_sampler, 它一次生成一个键列表。
在使用来自 sampler 的索引获取样本列表之后, 使用 collate_fn参数 传递的函数将样本列表整理成批。
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
2.不自动批处理
当 batch_size 和 batch_sampler 都为 None 时(batch_sampler的默认值已经为None), 自动批处理被禁用。从数据集中获得的每个样本都使用作为 collate_fn参数 传递的函数进行处理。
默认的 collate_fn 只是将 NumPy 数组转换为 PyTorch张量, 并保持其他所有内容不变。
for index in sampler:
yield collate_fn(dataset[index])
collate_fn 参数
当禁用自动批处理时, 对每个单独的数据样本调用 collate_fn, 并从数据加载器迭代器产生输出。在这种情况下, 默认的 collate_fn 只是转换 PyTorch 张量中的 NumPy 数组。
当启用自动批处理时, 每次调用 collate_fn 时都返回带有数据样本列表。将输入样本整理成 batch, 以便从数据加载器迭代器生成。
默认的 collate_fn 有以下属性:
1.它总是添加一个新维度作为 batch 维度。
2.它会自动将 NumPy 数组和Python数值转换为 PyTorch Tensors。
3.它保留了数据结构, 例如, 如果每个样本都是一个字典, 它输出一个字典, 具有相同的键集, 但值的类型转换为 PyTorch Tensors (如果值不能转换为 PyTorch Tensors, 将转换为列表)。
用户可以使用自定义collate_fn来实现自定义批处理, 例如, 沿着第一个维度以外的维度进行排序, 填充各种长度的序列, 或者添加对自定义数据类型的支持。