Preface
This post explains how mmdetection builds its dataloader. The dataloader controls iterative reading of the dataset, and it works hand in hand with a dataset class, which must be implemented first. For the dataset implementation, see the companion post on building mmdetection's dataset class.
1. Overall workflow
In PyTorch, constructing a DataLoader instance involves the following key parameters (excerpted from the DataLoader source):
Arguments:
    dataset (Dataset): dataset from which to load the data.
    batch_size (int, optional): how many samples per batch to load
        (default: ``1``).
    shuffle (bool, optional): set to ``True`` to have the data reshuffled
        at every epoch (default: ``False``).
    sampler (Sampler or Iterable, optional): defines the strategy to draw
        samples from the dataset. Can be any ``Iterable`` with ``__len__``
        implemented. If specified, :attr:`shuffle` must not be specified.
    batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
        returns a batch of indices at a time. Mutually exclusive with
        :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
        and :attr:`drop_last`.
    num_workers (int, optional): how many subprocesses to use for data
        loading. ``0`` means that the data will be loaded in the main process.
        (default: ``0``)
    collate_fn (callable, optional): merges a list of samples to form a
        mini-batch of Tensor(s). Used when using batched loading from a
        map-style dataset.
A brief summary of what each parameter means:
dataset: an instance of a class that inherits from Dataset;
batch_size: the number of samples per batch;
shuffle: if True, the data is reshuffled at the start of every epoch;
sampler: an iterable that yields dataset indices (shuffled or in order);
batch_sampler: iterates over the sampler's indices and groups them, so that batch_size samples at a time are fetched from the dataset;
collate_fn: merges the batch of samples into one mini-batch, e.g. padding images to a common height and width.
If these definitions are a little confusing, don't worry. Just remember that dataset, sampler, batch_sampler and dataloader are all iterables, i.e. objects that can be used in a "for ... in ...:" loop.
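To make the "everything is an iterable" point concrete, here is a minimal sketch in plain PyTorch (not mmdetection code); ToyDataset is a made-up stand-in for illustration:

```python
# A minimal sketch showing that dataset, sampler, batch_sampler and
# dataloader are all iterables, using only plain PyTorch.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """A map-style dataset: __getitem__ + __len__ is all DataLoader needs."""

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        return torch.tensor([idx], dtype=torch.float32)


loader = DataLoader(ToyDataset(), batch_size=2, shuffle=False)
# DataLoader builds a SequentialSampler and a BatchSampler internally;
# iterating yields tensors of shape (batch_size, 1); the last batch is
# smaller because drop_last defaults to False.
shapes = [batch.shape for batch in loader]
print(shapes)  # [torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([1, 1])]
```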
Now that the main DataLoader parameters are clear, let's look at how mmdetection's build_dataloader works. The walkthrough has two parts:
(1) How a dataloader object is instantiated. mmdetection mainly fills in the four parameters above: GroupSampler inherits from PyTorch's Sampler class, shuffle is almost always True, and for batch_sampler mmdetection reuses the BatchSampler class already implemented in PyTorch.
(2) The flow of reading one batch of data.
2. Instantiating the dataloader
2.1. Implementing the GroupSampler class
For the dataset implementation, see the post on building the dataset class. Here is the GroupSampler source:
import numpy as np
from torch.utils.data import Sampler


class GroupSampler(Sampler):

    def __init__(self, dataset, samples_per_gpu=1):
        assert hasattr(dataset, 'flag')
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.flag = dataset.flag.astype(np.int64)
        # np.bincount counts how often each group id (0 or 1) occurs
        self.group_sizes = np.bincount(self.flag)
        self.num_samples = 0
        for i, size in enumerate(self.group_sizes):
            # pad each group up to a multiple of samples_per_gpu
            self.num_samples += int(np.ceil(
                size / self.samples_per_gpu)) * self.samples_per_gpu

    def __iter__(self):
        indices = []
        for i, size in enumerate(self.group_sizes):
            # e.g. self.group_sizes = [942, 4069]: 942 images with
            # aspect ratio < 1, the rest with aspect ratio >= 1
            if size == 0:
                continue
            # indices of all images whose flag equals the current i;
            # self.flag stores the aspect-ratio group of every training image
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size
            np.random.shuffle(indice)  # shuffle the indices within the group
            # pad with randomly re-drawn indices so that the group size
            # becomes divisible by samples_per_gpu
            num_extra = int(np.ceil(size / self.samples_per_gpu)
                            ) * self.samples_per_gpu - len(indice)
            indice = np.concatenate(
                [indice, np.random.choice(indice, num_extra)])
            indices.append(indice)
        # merge the groups into one flat list, e.g. of length 5011
        indices = np.concatenate(indices)
        # split the list into chunks of samples_per_gpu and shuffle the chunk
        # order; every chunk contains indices from a single group
        indices = [
            indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
            for i in np.random.permutation(
                range(len(indices) // self.samples_per_gpu))
        ]
        indices = np.concatenate(indices)
        indices = indices.astype(np.int64).tolist()
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
The essential point is that GroupSampler implements __iter__, which makes it an iterable. The rough idea: suppose the dataset has 5000 images, so the indices run from 0 to 4999. np.random.shuffle permutes the indices within each aspect-ratio group. With a batch size of 2, this yields 2500 pairs, stored as small arrays in the list indices; finally iter(indices) returns an iterator over them.
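The index bookkeeping in __init__ can be traced on a toy example; the flag array below is an assumption for illustration (0 means aspect ratio < 1, 1 otherwise):

```python
# A self-contained sketch of the per-group counting and padding arithmetic
# that GroupSampler performs, on a toy 5-image flag array.
import numpy as np

flag = np.array([0, 1, 1, 0, 1], dtype=np.int64)
samples_per_gpu = 2

group_sizes = np.bincount(flag)  # images per aspect-ratio group
print(group_sizes.tolist())      # [2, 3]

num_samples = 0
for size in group_sizes:
    # each group is padded up to a multiple of samples_per_gpu
    num_samples += int(np.ceil(size / samples_per_gpu)) * samples_per_gpu
print(num_samples)               # 2 + 4 = 6

indice = np.where(flag == 1)[0]  # indices of images in group 1
print(indice.tolist())           # [1, 2, 4]
```

Because group 1 has 3 images but batches must hold samples_per_gpu images from the same group, one extra index is re-drawn from the group, which is why len(sampler) can exceed len(dataset).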
2.2. The BatchSampler class
For this part mmdetection uses the PyTorch source directly. Here it is:
from typing import List

from torch.utils.data import Sampler


class BatchSampler(Sampler[List[int]]):

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # emit the trailing incomplete batch unless drop_last is set
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # Can only be called if self.sampler has __len__ implemented.
        # We cannot enforce this condition, so we turn off typechecking
        # for the implementation below.
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore
As the source shows, BatchSampler is initialized with a sampler, and it also implements __iter__: each time a full batch of indices has been accumulated, the generator yields it with "yield batch", i.e. it returns one batch of indices at a time.
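This yield-based batching can be seen directly by driving PyTorch's own BatchSampler with a SequentialSampler:

```python
# Demonstrates how BatchSampler groups a sampler's indices into batches,
# using PyTorch's built-in classes.
from torch.utils.data import BatchSampler, SequentialSampler

sampler = SequentialSampler(range(5))  # yields 0, 1, 2, 3, 4
batches = list(BatchSampler(sampler, batch_size=2, drop_last=False))
print(batches)  # [[0, 1], [2, 3], [4]]

# With drop_last=True the trailing incomplete batch is discarded.
batches_dropped = list(BatchSampler(sampler, batch_size=2, drop_last=True))
print(batches_dropped)  # [[0, 1], [2, 3]]
```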
3. The flow of reading one batch of data
This flow is easier to show with a diagram than to describe in words (see the figure in the original post):
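In lieu of the diagram, the per-iteration flow can be replayed by hand in plain PyTorch: the batch_sampler yields a list of indices, the dataset's __getitem__ is called once per index, and collate_fn merges the samples into one batch. ToyDataset below is a hypothetical stand-in, not mmdetection code:

```python
# Manually replaying one DataLoader iteration step (single-process case).
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.full((2,), float(idx))


loader = DataLoader(ToyDataset(), batch_size=2)

indices = next(iter(loader.batch_sampler))      # first batch of indices
samples = [loader.dataset[i] for i in indices]  # fetch each sample
batch = loader.collate_fn(samples)              # default_collate stacks them
print(indices)      # [0, 1]
print(batch.shape)  # torch.Size([2, 2])
```

In mmdetection the only moving parts that change are the sampler (GroupSampler) and the collate_fn, which pads images within a batch to a common size; the iteration skeleton stays the same.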
Summary
This post covered how mmdetection constructs a DataLoader by implementing the dataset and sampler classes, and showed how the dataloader iterates over each batch of data internally.