目录
worker_init_fn (callable, optional)
generator (torch.Generator, 可选)
prefetch_factor (int, optional, keyword-only arg)
persistent_workers (bool, optional)
这个参数是一个函数,用于将多个样本数据合并成一个batch。如果你的数据包含不同形状的数据,那么你可能需要自定义这个函数。
pin_memory (bool, 可选)
作用是决定是否将数据预先加载至内存中。如果pin_memory=True,DataLoader会在返回之前,将数据加载到CUDA Pinned Memory中。这样做的好处在于,当我们想要将数据移动到GPU设备进行计算时,从固定内存到显存的传输速度要比从常规内存到显存的速度快,由此可以加速数据传输。详解推荐阅读:https://zhuanlan.zhihu.com/p/561544545
drop_last (bool, 可选)
当数据总量不能整除batch_size时,最后会余一部分样本。例如数据集总长为10,batch=3,则最后会余1个样本。drop_last就是控制是否丢掉余下的样本。True则丢弃,False不丢。
timeout (numeric, optional)
此参数表示等待来自子进程的返回数据的最长时间,仅当使用多进程加载时才实际使用。这是一种在数据加载时使用的防止死锁的机制。如果timeout设置为正数,并且从工作进程收到数据的时间超过了timeout中定义的秒数,或者如果没有从子进程获取数据,则会引发一个运行时错误。
-
数据预处理过程中出现了问题:如果数据预处理函数(如collate_fn或自定义的数据转换函数)出现了错误或死循环,可能会导致数据加载器无法获取下一个数据批次。
-
数据加载过程中遇到了I/O问题:如果从磁盘或网络加载数据时遇到了I/O问题(如磁盘故障或网络中断),可能会导致数据加载器无法获取下一个数据批次。
-
多进程/多线程数据加载时出现了死锁或竞争条件:当使用多进程或多线程加载数据时,如果出现了死锁或竞争条件,也可能导致数据加载器无法获取下一个数据批次。
worker_init_fn (callable, optional)
在启动新的子进程时执行的自定义的初始化函数。如果不为None,那么在设定随机种子后,数据加载前,这个函数会在每个子进程上调用,并把子进程id作为输入(一个在[0, num_workers - 1]的整数),默认为None。
generator (torch.Generator, 可选)
用于提供一个可替代的随机数生成器,用于对数据进行混洗(shuffle)操作。默认情况下,DataLoader使用PyTorch内置的随机数生成器torch.random进行混洗操作。但在某些情况下,您可能需要使用自定义的随机数生成器,此时则需要自定义的generator。
-
如果您同时设置了worker_init_fn参数,那么在每个工作进程中,worker_init_fn函数中设置的随机数生成器种子将优先于generator参数。这是因为每个工作进程都需要使用不同的随机种子,以避免产生相同的混洗顺序。
prefetch_factor (int, optional, keyword-only arg)
每个子进程预先加载的样本数,默认为2。例如prefetch_factor=2,num_workers=4,所有的进程则会预先加载2*4=8个样本。
persistent_workers (bool, optional)
默认情况下,当num_workers>0时,PyTorch会为每个epoch创建新的工作进程用于数据加载。这意味着在每个epoch开始时,PyTorch都会销毁上一个epoch中使用的工作进程,并创建新的工作进程。这种方式可以确保每个epoch中的数据顺序是独立的(在创建多线程),但同时也会带来一些开销,例如创建和销毁进程的时间开销。如果将persistent_workers设置为True,PyTorch将在第一个epoch时创建工作进程,然后在后续的epoch中重用这些工作进程,而不是在每个epoch开始时重新创建。这种方式可以提高数据加载的效率,因为它避免了频繁创建和销毁进程的开销。
code:torch/utils/data/dataloader.py#L346-L359
persistent_workers code:
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> '_BaseDataLoaderIter':
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()