2. pytorch数据加载DataLoader

DataLoader简介

        PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader 类。它表示数据集的 Python 可迭代对象, 支持
          1.map样式和iterable样式的数据集;
            map风格的数据集是实现__getitem__()和__len__()协议的数据集, 并表示从索引/键到数据样本的映射。
          2.自定义数据加载顺序;
          3.自动划分batch;
          4.单进程和多进程数据加载;
          5.自动内存固定。

DataLoader 参数作用

        Dataset 一次检索一个样本和标签。在训练模型的时候, 我们需要一个batch的样本, 并且每个epoch样本随机打乱(降低过拟合可能性), 多线程加载数据。DataLoader 的工作流程可以参考PyTorch DataLoader工作原理可视化

DataLoader(dataset, # Dataset类的实例
           batch_size=1, # 一个batch图片的个数
           shuffle=False, # 是否训练完一个epoch, 数据重新打乱
           sampler=None, # 自定义采样方法
           batch_sampler=None, # 把sampler的采样的样本根据batch_size组织成一个batch返回
           num_workers=0, # 加载数据进程的个数
           collate_fn=None, # 自定方法对每个batch的数据进一步进行处理
           pin_memory=False, # 是否将数据存储到显存
           drop_last=False, # 最后一个batch数据个数不够是否丢弃
           timeout=0,
           worker_init_fn=None, 
           *, 
           prefetch_factor=2,
           persistent_workers=False)

Sampler 参数

        torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的顺序。它们表示数据集索引上的可迭代对象。例如, 在随机梯度渐变(SGD)的常见情况下, Sampler 可以随机排列索引列表并一次生成一个索引或者一次生成一个 mini-batch 的索引。
        可以根据 DataLoader 的 shuffle 参数自动构造顺序的或乱序的采样器。或者, 用户可以使用 sampler 参数 来指定一个自定义sampler对象, 该对象每次都会产生下一个要获取的索引/键。每次产生一个 batch 索引列表的自定义 Sampler 可以作为 batch_sampler 参数。更加详细的讲解可以参考Sampler类与4种采样方式
        Sequential Sampler(顺序采样)
        Random Sampler(随机采样)
        Subset Random Sampler(子集随机采样)
        Weighted Random Sampler(加权随机采样)

批处理

    1.自动批处理
        当 batch_size(默认值1)不为 None 时, 数据加载器生成成批样本而不是单个样本。batch_size 和drop_last参数 用于指定数据加载器如何获取批量的数据集键。对于 Map样式 的数据集, 用户可以指定 batch_sampler, 它一次生成一个键列表。
        在使用来自 sampler 的索引获取样本列表之后, 使用 collate_fn参数 传递的函数将样本列表整理成批。

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

    2.不自动批处理
        当 batch_size 和 batch_sampler 都为 None 时(batch_sampler的默认值已经为None), 自动批处理被禁用。从数据集中获得的每个样本都使用作为 collate_fn参数 传递的函数进行处理。
        默认的 collate_fn 只是将 NumPy 数组转换为 PyTorch张量, 并保持其他所有内容不变。

for index in sampler:
    yield collate_fn(dataset[index])

collate_fn 参数

        当禁用自动批处理时, 对每个单独的数据样本调用 collate_fn, 并从数据加载器迭代器产生输出。在这种情况下, 默认的 collate_fn 只是转换 PyTorch 张量中的 NumPy 数组。
        当启用自动批处理时, 每次调用 collate_fn 时都返回带有数据样本列表。将输入样本整理成 batch, 以便从数据加载器迭代器生成。

        默认的 collate_fn 有以下属性:
         1.它总是添加一个新维度作为 batch 维度。
         2.它会自动将 NumPy 数组和Python数值转换为 PyTorch Tensors。
         3.它保留了数据结构, 例如, 如果每个样本都是一个字典, 它输出一个字典, 具有相同的键集, 但值的类型转换为 PyTorch Tensors (如果值不能转换为 PyTorch Tensors, 将转换为列表)。
        用户可以使用自定义collate_fn来实现自定义批处理, 例如, 沿着第一个维度以外的维度进行排序, 填充各种长度的序列, 或者添加对自定义数据类型的支持。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值