Pytorch Dataloader入门

在这里插入图片描述

Pytorch Dataloader code:torch/utils/data/dataloader.py#L71
Pytorch Dataset tutorial: tutorials/beginner/basics/data_tutorial.html


理论:

在训练模型时,我们通常希望:

  1. 以“mini-batch”方式传递样本,能够加速训练。
  2. 每个epoch都shuffle数据,能够减少模型过拟合。
  3. 使用 Python 的多进程multiprocessing 处理,能够加速数据读取。

Dataloader就是解决上述事情的,PyTorch的DataLoader用于封装数据集(Dataset),并提供数据打乱、并行加载等功能。以下是其主要功能和特性:

  1. batch_size:这是一个非常重要的参数,用于指定每个批次的数据大小。在实际的训练过程中,我们通常一次处理一个batch的数据,然后更新参数。
  2. shuffle:这个参数用于表示是否要在每个epoch开始的时候打乱数据集。这对于训练过程中的模型泛化是非常有帮助的。(通常train_dataloader是shuffle的,test_dataloader是不shuffle的)
  3. num_workers:这个参数用于指定加载数据时需要使用的子进程数量。使用更多的子进程可以减少数据加载的时间,但同时也会增加内存使用量。

除了上述三个主要用到的功能,Dataloader的全部输入参数详细解释如下:


参数:

  1. dataset (Dataset):用于从中加载数据的数据集。
  2. batch_size (int, 可选):每批加载的样本数, 默认为1。
  3. shuffle (bool, 可选):是否在每个epoch重新打乱数据,默认为False,不打乱。
  4. sampler (Sampler or Iterable, 可选):定义从数据集中采样的策略。可以是任何实现了 len方法的 Iterable。与shuffle参数是冲突的,两个只能指定一个。
  5. batch_sampler (Sampler or Iterable, 可选):类似于sampler,但一次返回一批索引。与batch_size,shuffle,sampler和drop_last参数互斥。
  6. num_workers (int, 可选):用于数据加载的子进程数。0表示在主进程中加载数据,默认0。
  7. collate_fn (callable, 可选):这个参数是一个函数,用于将多个样本数据合并成一个batch。如果你的数据包含不同形状的数据,那么你可能需要自定义这个函数。
  8. pin_memory (bool, 可选):作用是决定是否将数据预先加载至内存中。如果pin_memory=True,DataLoader会在返回之前,将数据加载到CUDA Pinned Memory中。这样做的好处在于,当我们想要将数据移动到GPU设备进行计算时,从固定内存到显存的传输速度要比从常规内存到显存的速度快,由此可以加速数据传输。详解:https://zhuanlan.zhihu.com/p/561544545
  9. drop_last (bool, 可选):当数据总量不能整除batch_size时,最后会余一部分样本。例如数据集总长为10,batch=3,则最后会余1个样本。drop_last就是控制是否丢掉余下的样本。True则丢弃,False不丢。
  10. timeout (numeric, optional):此参数表示等待来自子进程的返回数据的最长时间,仅当使用多进程加载时才实际使用。这是一种在数据加载时使用的防止死锁的机制。如果timeout设置为正数,并且从工作进程收到数据的时间超过了timeout中定义的秒数,或者如果没有从子进程获取数据,则会引发一个运行时错误。
    1. 数据预处理过程中出现了问题:如果数据预处理函数(如collate_fn或自定义的数据转换函数)出现了错误或死循环,可能会导致数据加载器无法获取下一个数据批次。
    2. 数据加载过程中遇到了I/O问题:如果从磁盘或网络加载数据时遇到了I/O问题(如磁盘故障或网络中断),可能会导致数据加载器无法获取下一个数据批次。
    3. 多进程/多线程数据加载时出现了死锁或竞争条件:当使用多进程或多线程加载数据时,如果出现了死锁或竞争条件,也可能导致数据加载器无法获取下一个数据批次。
  11. worker_init_fn (callable, optional):在启动新的子进程时执行的自定义的初始化函数。如果不为None,那么在设定随机种子后,数据加载前,这个函数会在每个子进程上调用,并把子进程id作为输入(一个在[0, num_workers - 1]的整数),默认为None。
  12. generator (torch.Generator, 可选):用于提供一个可替代的随机数生成器,用于对数据进行混洗(shuffle)操作。默认情况下,DataLoader使用PyTorch内置的随机数生成器torch.random进行混洗操作。但在某些情况下,您可能需要使用自定义的随机数生成器,此时则需要自定义的generator。
    1. 如果您同时设置了worker_init_fn参数,那么在每个工作进程中,worker_init_fn函数中设置的随机数生成器种子将优先于generator参数。这是因为每个工作进程都需要使用不同的随机种子,以避免产生相同的混洗顺序。
  13. prefetch_factor (int, optional, keyword-only arg):每个子进程预先加载的样本数,默认为2。例如prefetch_factor=2,num_workers=4,所有的进程则会预先加载2*4=8个样本。
  14. persistent_workers (bool, optional):默认情况下,当num_workers>0时,PyTorch会为每个epoch创建新的工作进程用于数据加载。这意味着在每个epoch开始时,PyTorch都会销毁上一个epoch中使用的工作进程,并创建新的工作进程。这种方式可以确保每个epoch中的数据顺序是独立的(在创建多线程),但同时也会带来一些开销,例如创建和销毁进程的时间开销。如果将persistent_workers设置为True,PyTorch将在第一个epoch时创建工作进程,然后在后续的epoch中重用这些工作进程,而不是在每个epoch开始时重新创建。这种方式可以提高数据加载的效率,因为它避免了频繁创建和销毁进程的开销。

代码:

下面是Dataloader读取一个mini-batch数据的代码样例,展示了Dataset和Dataloader的常用功能:(可直接执行)

from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data_len):
        self.data_len = data_len

    def __getitem__(self, idx):
        return f"Data at index {idx}"

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    data_len = 10

    dataset = CustomDataset(data_len)
    print("Dataset:")
    for i in range(data_len):
        data_item = dataset[i]
        print(data_item)

    print("0. DataLoader(batch_size=1, shuffle=False, drop_last=False):")
    dataloader = DataLoader(dataset)
    for batch in dataloader:
        print(batch)

    print("\n1. DataLoader(batch_size=4):")
    dataloader = DataLoader(dataset, batch_size=4)
    for batch in dataloader:
        print(batch)

    print("\n2. DataLoader(batch_size=4, shuffle=True):")
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    for batch in dataloader:
        print(batch)

    print("\n3. DataLoader(batch_size=4, shuffle=True, drop_last=True):")
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
    for batch in dataloader:
        print(batch)
        
----------------------------------------------------------------------------------------------------------
Dataset:    # Dataset每次只能读取一个样本
Data at index 0
Data at index 1
Data at index 2
Data at index 3
Data at index 4
Data at index 5
Data at index 6
Data at index 7
Data at index 8
Data at index 9

# Dataloader默认的参数
0. DataLoader(batch_size=1, shuffle=False, drop_last=False):
['Data at index 0']
['Data at index 1']
['Data at index 2']
['Data at index 3']
['Data at index 4']
['Data at index 5']
['Data at index 6']
['Data at index 7']
['Data at index 8']
['Data at index 9']

# Dataloader跟返回值相关的参数
1. DataLoader(batch_size=4):
['Data at index 0', 'Data at index 1', 'Data at index 2', 'Data at index 3']
['Data at index 4', 'Data at index 5', 'Data at index 6', 'Data at index 7']
['Data at index 8', 'Data at index 9']

2. DataLoader(batch_size=4, shuffle=True):
['Data at index 8', 'Data at index 3', 'Data at index 2', 'Data at index 1']
['Data at index 4', 'Data at index 0', 'Data at index 6', 'Data at index 9']
['Data at index 7', 'Data at index 5']

3. DataLoader(batch_size=4, shuffle=True, drop_last=True):
['Data at index 6', 'Data at index 5', 'Data at index 7', 'Data at index 2']
['Data at index 1', 'Data at index 3', 'Data at index 8', 'Data at index 0']


总结:

上述代码中展示了Dataset和Dataloader的基本功能:

  1. line43-line53:Dataset每次只能读取一个样本,而且是顺序的。
  2. line55-line66:Dataloader的默认参数:batch_size=1, shuffle=False, drop_last=False,每次返回的是一个mini-batch的数据,类型是一个列表,默认长度为1。
  3. line69-line72:Dataloader修改参数batch_size=4,返回的mini-batch数据中长度变为4,而且是顺序的。
  4. line74-line77:Dataloader修改参数shuffle=True,会改变数据的返回循序,是随机打乱的。
  5. line79-line81:Dataloader修改参数drop_last=True,会丢弃掉最后样本数无法被batch_size整除的样本。代码中总共10个样本,batch_size=4,10%4=2,打乱后,最后2个样本被丢弃了。

嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~。也欢迎大家关注我的GitHub仓库,我出的所有博文教程都是无偿分享的,只求个关注与Star~,多谢大家支持!

GitHub - gy-7/CV-Road (后续教程相关所有代码都会维护到此仓库)

  • 29
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值