Pytorch Dataloader参数补充篇(collate_fn、pin_memory、dorp_last、prefetch_factor、persistent_workers)

目录

collate_fn (callable, 可选)

pin_memory (bool, 可选)

drop_last (bool, 可选)

timeout (numeric, optional)

worker_init_fn (callable, optional)

generator (torch.Generator, 可选)

prefetch_factor (int, optional, keyword-only arg)

persistent_workers (bool, optional)


这个参数是一个函数,用于将多个样本数据合并成一个batch。如果你的数据包含不同形状的数据,那么你可能需要自定义这个函数。

pin_memory (bool, 可选)

作用是决定是否将数据预先加载至内存中。如果pin_memory=True,DataLoader会在返回之前,将数据加载到CUDA Pinned Memory中。这样做的好处在于,当我们想要将数据移动到GPU设备进行计算时,从固定内存到显存的传输速度要比从常规内存到显存的速度快,由此可以加速数据传输。详解推荐阅读:https://zhuanlan.zhihu.com/p/561544545

drop_last (bool, 可选)

当数据总量不能整除batch_size时,最后会余一部分样本。例如数据集总长为10,batch=3,则最后会余1个样本。drop_last就是控制是否丢掉余下的样本。True则丢弃,False不丢。

timeout (numeric, optional)

此参数表示等待来自子进程的返回数据的最长时间,仅当使用多进程加载时才实际使用。这是一种在数据加载时使用的防止死锁的机制。如果timeout设置为正数,并且从工作进程收到数据的时间超过了timeout中定义的秒数,或者如果没有从子进程获取数据,则会引发一个运行时错误。

  1. 数据预处理过程中出现了问题:如果数据预处理函数(如collate_fn或自定义的数据转换函数)出现了错误或死循环,可能会导致数据加载器无法获取下一个数据批次。

  2. 数据加载过程中遇到了I/O问题:如果从磁盘或网络加载数据时遇到了I/O问题(如磁盘故障或网络中断),可能会导致数据加载器无法获取下一个数据批次。

  3. 多进程/多线程数据加载时出现了死锁或竞争条件:当使用多进程或多线程加载数据时,如果出现了死锁或竞争条件,也可能导致数据加载器无法获取下一个数据批次。

worker_init_fn (callable, optional)

在启动新的子进程时执行的自定义的初始化函数。如果不为None,那么在设定随机种子后,数据加载前,这个函数会在每个子进程上调用,并把子进程id作为输入(一个在[0, num_workers - 1]的整数),默认为None。

generator (torch.Generator, 可选)

用于提供一个可替代的随机数生成器,用于对数据进行混洗(shuffle)操作。默认情况下,DataLoader使用PyTorch内置的随机数生成器torch.random进行混洗操作。但在某些情况下,您可能需要使用自定义的随机数生成器,此时则需要自定义的generator。

  1. 如果您同时设置了worker_init_fn参数,那么在每个工作进程中,worker_init_fn函数中设置的随机数生成器种子将优先于generator参数。这是因为每个工作进程都需要使用不同的随机种子,以避免产生相同的混洗顺序。

prefetch_factor (int, optional, keyword-only arg)

每个子进程预先加载的样本数,默认为2。例如prefetch_factor=2,num_workers=4,所有的进程则会预先加载2*4=8个样本。

persistent_workers (bool, optional)

默认情况下,当num_workers>0时,PyTorch会为每个epoch创建新的工作进程用于数据加载。这意味着在每个epoch开始时,PyTorch都会销毁上一个epoch中使用的工作进程,并创建新的工作进程。这种方式可以确保每个epoch中的数据顺序是独立的(在创建多线程),但同时也会带来一些开销,例如创建和销毁进程的时间开销。如果将persistent_workers设置为True,PyTorch将在第一个epoch时创建工作进程,然后在后续的epoch中重用这些工作进程,而不是在每个epoch开始时重新创建。这种方式可以提高数据加载的效率,因为它避免了频繁创建和销毁进程的开销。

code:torch/utils/data/dataloader.py#L346-L359

persistent_workers code:

# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

  • 37
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值