Having just finished tidying up the config files, I'm taking the opportunity to write up the data-loading pipeline as well, for easy reference later.
Source Code Analysis
class MyIterable():
    def __init__(self):
        self.data = [1, 2, 3, 4]

    def __iter__(self):  # returns an iterator
        return MyIterator(self.data)

    def __getitem__(self, idx):  # allows indexing with []
        return self.data[idx]


class MyIterator():
    def __init__(self, data):
        self.data = data
        self.counter = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.counter >= len(self.data):
            raise StopIteration()
        data = self.data[self.counter]
        self.counter += 1
        return data


my_iterable = MyIterable()
for d in my_iterable:
    print(d)
print(my_iterable[2])
Output:
1
2
3
4
3
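The for loop above is just syntactic sugar for the iterator protocol; the same traversal can be written by calling iter() and next() by hand. A small sketch reusing the MyIterable class defined above:

it = iter(MyIterable())  # __iter__ returns a fresh MyIterator
print(next(it))  # 1
print(next(it))  # 2
print(next(it))  # 3
print(next(it))  # 4
# one more next(it) would raise StopIteration, which is what ends a for loop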
Dataset returns a single image.
DataLoader returns a batch of images.
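To make that division of labor concrete, here is a minimal sketch using the public paddle.io API; the RandomImageDataset class, sample shapes, and sizes are made up for illustration:

import numpy as np
from paddle.io import Dataset, DataLoader

class RandomImageDataset(Dataset):  # hypothetical toy dataset
    def __init__(self, num_samples=8):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        # Dataset returns ONE sample: a fake 3x32x32 "image" and its label
        image = np.random.rand(3, 32, 32).astype('float32')
        label = np.array([idx % 2], dtype='int64')
        return image, label

    def __len__(self):
        return self.num_samples

loader = DataLoader(RandomImageDataset(), batch_size=4, shuffle=True)
for images, labels in loader:
    # DataLoader returns a BATCH: images [4, 3, 32, 32], labels [4, 1]
    print(images.shape, labels.shape)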
- batch_size (int|None) - the number of samples in each mini-batch; an alternative to batch_sampler. If batch_sampler is not set, a paddle.io.BatchSampler is created from batch_size, shuffle and drop_last. Default: 1.
- batch_sampler - sets how data is drawn from the dataset. Default: None.
- collate_fn (callable) - specifies how a list of samples is combined into mini-batch data (see the sketch after this list). When collate_fn is None, each field of the samples is stacked along dimension 0 (like np.stack(…, axis=0)) to form the mini-batch. Default: None.
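A sketch of what collate_fn actually sees: it receives a plain Python list containing the samples of one mini-batch. By default each field is stacked along axis 0 (which is why samples normally need matching shapes); a custom collate_fn can, for example, pad variable-length samples instead. The pad_collate helper below is hypothetical, not part of Paddle:

import numpy as np

# a "batch" as handed to collate_fn: a list of (image, label) samples
samples = [(np.random.rand(3, 32, 32).astype('float32'), np.int64(i % 2)) for i in range(4)]

# roughly what the default collation does: stack each field along axis 0
images = np.stack([s[0] for s in samples], axis=0)  # (4, 3, 32, 32)
labels = np.stack([s[1] for s in samples], axis=0)  # (4,)

# a hypothetical custom collate_fn: pad variable-length sequences to the longest one
def pad_collate(batch):
    max_len = max(len(seq) for seq in batch)
    out = np.zeros((len(batch), max_len), dtype='int64')
    for i, seq in enumerate(batch):
        out[i, :len(seq)] = seq
    return out

print(pad_collate([np.array([1, 2]), np.array([3, 4, 5])]).shape)  # (2, 3)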
class DataLoader(object):
    def __init__(self, dataset, batch_sampler=None, batch_size=1, collate_fn=None, ...):
        if batch_sampler is not None:
            self.batch_sampler = batch_sampler
            ...
        else:
            ...
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(dataset, batch_size)
            else:
                self.batch_sampler = BatchSampler(
                    dataset=dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    drop_last=drop_last)

    def __iter__(self):
        if self.num_workers == 0:
            return _DataLoaderIterSingleProcess(self)  # single-process data loading
        else:
            return _DataLoaderIterMultiProcess(self)   # multi-process data loading

    def __call__(self):
        return self.__iter__()

    def __len__(self):
        ...
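Note that the batch_sampler built in __init__ only yields lists of indices; no samples are read at this point. A quick sketch of paddle.io.BatchSampler on its own (TinyDataset is a made-up placeholder):

import numpy as np
from paddle.io import Dataset, BatchSampler

class TinyDataset(Dataset):  # hypothetical 10-sample dataset
    def __getitem__(self, idx):
        return np.array([idx], dtype='int64')

    def __len__(self):
        return 10

sampler = BatchSampler(dataset=TinyDataset(), batch_size=4, shuffle=False, drop_last=False)
for indices in sampler:
    print(indices)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]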
class _MapDatasetFetcher(_DatasetFetcher):
    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch,
                                                 collate_fn, drop_last)

    def fetch(self, batch_indices, done_event=None):
        if self.auto_collate_batch:
            data = []
            for idx in batch_indices:
                if done_event is None or not done_event.is_set():
                    data.append(self.dataset[idx])  # the key line: this is where DataLoader indexes into Dataset
                else:
                    return None

            global _WARNING_TO_LOG
            if not isinstance(data[0], (Sequence, Mapping)) \
                    and _WARNING_TO_LOG:
                self._log_warning()
                _WARNING_TO_LOG = False
        else:
            data = self.dataset[batch_indices]

        if self.collate_fn:
            data = self.collate_fn(data)
        return data
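Putting the pieces together: for a map-style dataset, one iteration of the DataLoader amounts to the batch_sampler producing index lists, fetch() indexing the Dataset with them, and collate_fn (or the default stacking) turning the sample list into a batch. A condensed sketch of that flow with hypothetical names, not Paddle's actual internals:

import numpy as np
from paddle.io import Dataset, BatchSampler

class TinyDataset(Dataset):  # hypothetical map-style dataset
    def __getitem__(self, idx):
        return np.array([idx, idx * 10], dtype='int64')

    def __len__(self):
        return 6

def default_collate(batch):
    # mimics the default behavior: stack samples along axis 0
    return np.stack(batch, axis=0)

dataset = TinyDataset()
for batch_indices in BatchSampler(dataset=dataset, batch_size=4, drop_last=False):
    samples = [dataset[idx] for idx in batch_indices]  # what _MapDatasetFetcher.fetch does
    batch = default_collate(samples)
    print(batch_indices, batch.shape)  # [0, 1, 2, 3] (4, 2), then [4, 5] (2, 2)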
- num_workers (int) - the number of subprocesses used to load data. If 0, no subprocesses are started and data is loaded in the main process. Default: 0.
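This num_workers value is exactly what __iter__ above checks to choose between the single-process and multi-process reader. A minimal sketch turning on two worker subprocesses (TinyDataset is a made-up placeholder; the __main__ guard keeps worker spawning safe on all platforms):

import numpy as np
from paddle.io import Dataset, DataLoader

class TinyDataset(Dataset):  # hypothetical dataset
    def __getitem__(self, idx):
        return np.array([idx], dtype='int64')

    def __len__(self):
        return 8

if __name__ == '__main__':
    # num_workers=0 -> _DataLoaderIterSingleProcess; num_workers=2 -> _DataLoaderIterMultiProcess
    loader = DataLoader(TinyDataset(), batch_size=4, num_workers=2)
    print(sum(1 for _ in loader))  # 2 mini-batches of 4 samples each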