问题描述
Pytorch的Dataloader一般会设置多个worker加载数据以提升训练速度。然而,可以发现每个epoch开始的时候数据加载的耗时总会高一截,具体表现如图:
问题原因
正如Pytorch Forum这个讨论一样,DataLoader
在每个epoch开始的时候都会重新创建一次,因此每个epocch开始所有的worker会重新开始prefetching过程,因此速度会变慢。
Solution
参考这里的代码, 可以用MultiEpochsDataLoader
代替原来的DataLoader
即可。
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._DataLoader__initialized = False
self.batch_sampler = _RepeatSampler(self.batch_sampler)
self._DataLoader__initialized = True
self.iterator = super().__iter__()
def __len__(self):
return len(self.batch_sampler.sampler)
def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever.
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
改进之后效果
Reference
相关讨论:
Pytorch讨论1
Pytorch讨论2
Github PR