pytorch数据载入Dataloader类
pytorch的数据导入一般均由Dataloader方法实现,但Dataloader含1000多行代码,阅读起来让人望而却步,pytorch官方文档对该部分所实现的功能有大致介绍,能够帮助你建立关于Dataloader功能的一个框架,文档为全英文,但值得阅读。
pytorch官方文档Dataloader
涉及迭代器与生成器的相关知识
下述代码为使用DataLoader载入数据的一般形式及DataLoader类的输入参数。
train_data_loader = torch.utils.data.DataLoader(train_data_set,
batch_size=batch_size,
shuffle=True,
num_workers=nw,
collate_fn=train_data_set.collate_fn)
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: _collate_fn_t = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
在DataLoader类的众多参数中,可以看到有两种sampler:sampler和batch_sampler,都默认为None。sampler的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,按batch_size的大小划分为batch组的index。
sampler和batch_sampler均由其对应的类来实现,两个类的定义中均含有def __iter__(self):
,表明其为迭代器(迭代器的相关知识还不太了解,有错请指出),在调试过程中两个类的返回值均为地址值,未看到图片索引列表,如下图:
其中,sampler和batch_sampler的元素个数分别为5717、1430,此时batch_size设置为4,5717 / 4 = 1429...1
,证明batch_sampler是将sampler生成的indices,按照batch_size的大小划分为batch组的index。
下面这段代码是class BatchSampler(Sampler[List[int]]):
类中对batch划分的实现,通过yeild实现对batch的返回。yeils与生成器self.sample中存放图片的索引。
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
Dataloader对数据的处理主要就通过sampler和batch_sampler实现。
参考资料:
Pytorch的DataLoader, DataSet, Sampler之间的关系
Pytorch Sampler详解
Python迭代器