The previous post covered the routine parts of the dataloader: data augmentation, dataset concatenation, and so on. This post turns to the data sampler.
Picking up where we left off, we start from the second half of data/build.py:
data_loaders = []
for dataset in datasets:  # looks like a loop, but for training `datasets` holds a single element:
    # everything was concatenated earlier, so datasets == [dataset];
    # for testing no concatenation is done, and `datasets` really is a list of datasets
    '''
    # data/build.py build_dataset()
    # for testing, return a list of datasets
    if not is_train:
        return datasets
    # for training, concatenate all datasets into a single one
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)  # concatenate every dataset in the list
    return [dataset]
    '''
    sampler = make_data_sampler(dataset, shuffle, is_distributed)  # this sampler and the batch sampler below are the focus of this post
    batch_sampler = make_batch_data_sampler(
        dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, is_train, batch_stitch=batch_stitch
    )
    if not is_train:
        # NOTE: the original collator, used for testing
        collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
            BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
    else:
        # NOTE: the self-defined BatchCollatorSynthesize() is used for training
        collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
            BatchCollatorSynthesize(cfg.DATALOADER.SIZE_DIVISIBILITY)
    num_workers = cfg.DATALOADER.NUM_WORKERS
    if isinstance(batch_sampler, list):
        data_loader = [torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler[0],  # the first batch_sampler
            collate_fn=collator,
        )]
        data_loader.append(torch.utils.data.DataLoader(  # the second of the two data loaders
            dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler[1],  # the second batch_sampler
            collate_fn=collator))
    else:
        data_loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=collator,
        )
    data_loaders.append(data_loader)  # the outer list only matters when there are multiple datasets;
                                      # note that data_loader itself can already be a list of two loaders
if is_train:
    # during training, a single (possibly concatenated) data_loader is returned
    assert len(data_loaders) == 1
    return data_loaders[0]
return data_loaders
From the code above: two different batch_samplers define two data_loaders, which are packed into a list that serves as the data loader for one dataset. So next we look at these two batch samplers.
The code above also makes a containment relationship clear:
sampler -> batch_sampler -> data_loader
data_loader = dataset + batch_sampler + collate_fn + num_workers
Here, dataset supplies the data, batch_sampler decides how indices are drawn from the dataset, and collate_fn decides how multiple images are assembled into one batch.
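To make this containment relationship concrete, here is a minimal self-contained sketch (toy dataset and names of my own, not from the repo):
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, BatchSampler

class ToyDataset(Dataset):
    # a 6-sample toy dataset; __getitem__ supplies the data for one index
    def __len__(self):
        return 6
    def __getitem__(self, idx):
        return torch.full((3,), float(idx))

dataset = ToyDataset()
sampler = SequentialSampler(dataset)                                   # sampler: which index to draw next
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=False)   # batch_sampler: pack indices into batches
collate = lambda samples: torch.stack(samples, dim=0)                  # collate_fn: assemble samples into one batch tensor

loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate, num_workers=0)
for batch in loader:
    print(batch.shape)   # torch.Size([2, 3]) for each of the 3 batches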
(1) sampler = make_data_sampler(dataset, shuffle, is_distributed)
def make_data_sampler(dataset, shuffle, distributed):
    if distributed:
        return samplers.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler
The sampler defines the strategy for drawing individual samples from the dataset:
torch.utils.data.sampler.RandomSampler: random sampling
torch.utils.data.sampler.SequentialSampler: sequential sampling
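Both yield dataset indices one at a time; a quick illustration with standard PyTorch (toy data, nothing repo-specific):
import torch

data = ["img0", "img1", "img2", "img3"]                      # any sized collection works
print(list(torch.utils.data.SequentialSampler(data)))        # [0, 1, 2, 3] -- indices in order
print(list(torch.utils.data.RandomSampler(data)))            # e.g. [2, 0, 3, 1] -- a random permutation of indices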
(2) batch_sampler = make_batch_data_sampler(
        dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, is_train, batch_stitch=batch_stitch
    )
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0, is_train=True, batch_stitch=False,
):
    if aspect_grouping:  # group images by aspect ratio
        if not isinstance(aspect_grouping, (list, tuple)):
            aspect_grouping = [aspect_grouping]
        aspect_ratios = _compute_aspect_ratios(dataset)  # list of every image's aspect ratio
        group_ids = _quantize(aspect_ratios, aspect_grouping)  # turn the aspect-ratio list into a list of group ids,
                                                               # e.g. [0, 1, 1, 0, 1, 0, 0, 0, 1] (see the sketch after this function)
        if not is_train:
            # note that this is samplers, i.e. from . import samplers
            batch_sampler = samplers.GroupedBatchSampler(  # wrap the sampler into a batch sampler
                sampler, group_ids, images_per_batch, drop_uneven=False
            )
        else:
            # NOTE: for slicing training
            batch_sampler = samplers.GroupedBatchSampler(
                sampler, group_ids, images_per_batch, drop_uneven=True
            )
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=False
        )
    if num_iters is not None:
        if is_train:
            if batch_stitch:
                return samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter, False)
            return [samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter, is_regular_epoch) for is_regular_epoch in [True, False]]
        batch_sampler = samplers.IterationBasedBatchSampler(
            batch_sampler, num_iters, start_iter
        )
    return batch_sampler
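For intuition on group_ids, here is a guess at what _quantize does (a sketch, not the repo's actual implementation): bucket each aspect ratio by the aspect_grouping thresholds, so a single threshold at 1.0 separates portrait images from landscape ones.
import bisect

def quantize_sketch(aspect_ratios, bins):
    # hypothetical helper: the group id is the number of thresholds the ratio exceeds
    bins = sorted(bins)
    return [bisect.bisect_right(bins, ratio) for ratio in aspect_ratios]

print(quantize_sketch([0.75, 1.33, 1.5, 0.8, 1.2], bins=[1.0]))  # [0, 1, 1, 0, 1]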
As you can see, this function turns the sampler into a batch_sampler.
This again reflects the relationship among sampler, batch_sampler, collate_fn, dataset and dataloader:
in this code, the sampler draws one sample (index) at a time, and the batch_sampler packs the indices drawn by the sampler into batches and returns them.
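A minimal illustration of that packing step with the plain PyTorch BatchSampler (toy data, nothing repo-specific):
import torch

data = list(range(10))                                   # a toy "dataset" of 10 elements
sampler = torch.utils.data.SequentialSampler(data)       # yields one index at a time: 0, 1, 2, ...
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size=4, drop_last=False)
print(list(batch_sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- indices packed into batches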
samplers.GroupedBatchSampler is the repo's own rewritten batch sampler:
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that elements from the same group should appear in groups of batch_size.
    It also tries to provide mini-batches which follows an ordering which is
    as close as possible to the ordering from the original sampler.
    Arguments:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_uneven (bool): If ``True``, the sampler will drop the batches whose
            size is less than ``batch_size``
    """
    def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = torch.as_tensor(group_ids)  # group id of every sample in the dataset
        assert self.group_ids.dim() == 1
        self.batch_size = batch_size
        # NOTE: the batch size needed for synthesizing (stitching)
        self.syn_need_batch_size = int(batch_size * cfg.STITCHER.NUM_IMAGES_STITCH)  # NUM_IMAGES_STITCH = 4
        self.drop_uneven = drop_uneven
        self.groups = torch.unique(self.group_ids).sort(0)[0]  # e.g. self.groups = tensor([0, 1])
        self._can_reuse_batches = False

    def _prepare_batches(self):
        dataset_size = len(self.group_ids)
        # get the sampled indices from the sampler
        sampled_ids = torch.as_tensor(list(self.sampler))  # index of every sampled element, e.g. [3, 9, 1, 4, 10]
        # self.sampler is an iterable, so list(sampler) calls its __iter__ and collects every sampled index into a list
        # potentially not all elements of the dataset were sampled
        # by the sampler (e.g., DistributedSampler).
        # construct a tensor which contains -1 if the element was
        # not sampled, and a non-negative number indicating the
        # order where the element was sampled.
        # for example, if sampled_ids = [3, 1] and dataset_size = 5,
        # the order is [-1, 1, -1, 0, -1]
        order = torch.full((dataset_size,), -1, dtype=torch.int64)  # a tensor of dataset_size copies of -1
        order[sampled_ids] = torch.arange(len(sampled_ids))  # record, for each sampled index, the order in which it was drawn
        # e.g. [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] -> [-1, 2, -1, 0, 3, -1, -1, -1, -1, 1, 4]
        # sampled_ids [3, 9, 1, 4, 10] line up with torch.arange(len(sampled_ids)) = [0, 1, 2, 3, 4]
        # get a mask with the elements that were sampled
        mask = order >= 0  # bool mask
        # find the elements that belong to each individual cluster
        clusters = [(self.group_ids == i) & mask for i in self.groups]
        # get relative order of the elements inside each cluster
        # that follows the order from the sampler
        relative_order = [order[cluster] for cluster in clusters]
        # with the relative order, find the absolute order in the
        # sampled space
        permutation_ids = [s[s.sort()[1]] for s in relative_order]
        # permute each cluster so that they follow the order from
        # the sampler
        permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]
        # splits each cluster in batch_size, and merge as a list of tensors
        # splits = [c.split(self.batch_size) for c in permuted_clusters]
        splits = [c.split(self.syn_need_batch_size) for c in permuted_clusters]
        merged = tuple(itertools.chain.from_iterable(splits))
        # now each batch internally has the right order, but
        # they are grouped by clusters. Find the permutation between
        # different batches that brings them as close as possible to
        # the order that we have in the sampler. For that, we will consider the
        # ordering as coming from the first element of each batch, and sort
        # correspondingly
        first_element_of_batch = [t[0].item() for t in merged]
        # get an inverse mapping from sampled indices to the position where
        # they occur (as returned by the sampler)
        inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
        # from the first element in each batch, get a relative ordering
        first_index_of_batch = torch.as_tensor(
            [inv_sampled_ids_map[s] for s in first_element_of_batch]
        )
        # permute the batches so that they approximately follow the order
        # from the sampler
        permutation_order = first_index_of_batch.sort(0)[1].tolist()
        # finally, permute the batches
        batches = [merged[i].tolist() for i in permutation_order]
        if self.drop_uneven:
            kept = []
            for batch in batches:
                # if len(batch) == self.batch_size:
                if len(batch) == self.syn_need_batch_size:
                    kept.append(batch)
            batches = kept
        return batches

    def __iter__(self):
        if self._can_reuse_batches:
            batches = self._batches
            self._can_reuse_batches = False
        else:
            batches = self._prepare_batches()
        self._batches = batches
        return iter(batches)

    def __len__(self):
        if not hasattr(self, "_batches"):
            self._batches = self._prepare_batches()
            self._can_reuse_batches = True
        return len(self._batches)

    # NOTE: needed for slice training
    def __call__(self, is_regular_epoch):  # look here, this is the key part!
        self.syn_need_batch_size = self.batch_size
        if not is_regular_epoch:
            self.syn_need_batch_size = int(self.syn_need_batch_size * cfg.STITCHER.NUM_IMAGES_STITCH)  # multiplied by 4 again?
        return self
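To see what _prepare_batches actually produces, here is a standalone walk-through in plain PyTorch (it re-runs the same steps outside the class, so cfg is not needed), using the example group_ids from earlier, a sequential sampler, and a batch size of 2:
import itertools
import torch

group_ids = torch.as_tensor([0, 1, 1, 0, 1, 0, 0, 0, 1])   # the example group ids from above
sampled_ids = torch.arange(len(group_ids))                  # pretend the sampler returned 0..8 in order
batch_size = 2                                              # stands in for syn_need_batch_size

order = torch.full((len(group_ids),), -1, dtype=torch.int64)
order[sampled_ids] = torch.arange(len(sampled_ids))          # order in which each index was drawn
mask = order >= 0
clusters = [(group_ids == i) & mask for i in torch.unique(group_ids)]
relative_order = [order[c] for c in clusters]
permutation_ids = [s[s.sort()[1]] for s in relative_order]
permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]   # group 0 -> [0, 3, 5, 6, 7], group 1 -> [1, 2, 4, 8]
merged = tuple(itertools.chain.from_iterable(c.split(batch_size) for c in permuted_clusters))
# per-group batches: [0, 3], [5, 6], [7], [1, 2], [4, 8]
first_index_of_batch = torch.as_tensor([t[0].item() for t in merged])  # sampler order == index here, so the inverse map is skipped
permutation_order = first_index_of_batch.sort(0)[1].tolist()
batches = [merged[i].tolist() for i in permutation_order]
print(batches)  # [[0, 3], [1, 2], [4, 8], [5, 6], [7]] -- every batch contains images from a single group
# with drop_uneven=True the leftover [7] batch would be dropped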
Everything before this is the standard GroupedBatchSampler logic; the key part is the __call__ at the end. is_regular_epoch decides whether a regular batch or a Stitcher batch is used: for a Stitcher batch, syn_need_batch_size becomes 4x the batch_size. Note that __call__ first resets self.syn_need_batch_size = self.batch_size instead of reusing the value assigned in __init__.
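To make the switch concrete, here is a tiny stand-in class, not the repo class, that mimics only the __call__ logic (how the returned object is later consumed is inferred from the code above):
class FakeGroupedBatchSampler:
    # a stand-in reproducing only the batch-size switching of GroupedBatchSampler.__call__
    def __init__(self, batch_size, num_images_stitch=4):  # 4 = cfg.STITCHER.NUM_IMAGES_STITCH
        self.batch_size = batch_size
        self.num_images_stitch = num_images_stitch
        self.syn_need_batch_size = batch_size * num_images_stitch  # set in __init__, as in the repo
    def __call__(self, is_regular_epoch):
        self.syn_need_batch_size = self.batch_size                  # always reset to batch_size first
        if not is_regular_epoch:
            self.syn_need_batch_size *= self.num_images_stitch      # stitcher epoch: draw 4x as many indices per batch
        return self

bs = FakeGroupedBatchSampler(batch_size=2)
print(bs(True).syn_need_batch_size)   # 2 -> regular epoch
print(bs(False).syn_need_batch_size)  # 8 -> stitcher epoch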
Question:
Why does __init__ also set self.syn_need_batch_size = int(batch_size * cfg.STITCHER.NUM_IMAGES_STITCH)? That value is only ever used when __iter__ (via _prepare_batches) runs.
Having gone through the batch sampler code, it turns out that all the batch sampler does is change the batch size according to is_regular_epoch.
The actual image stitching is done in the collate_fn, i.e. the collator; the sampler only changes the batch_size, so the collator receives a different number of images per batch.
The batch_size argument of DataLoader exists only to be handed over to a batch_sampler.
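In standard PyTorch the two ways of getting batched indices are equivalent; a quick sketch with toy data (this repo uses the explicit batch_sampler form):
import torch
from torch.utils.data import DataLoader, SequentialSampler, BatchSampler

data = torch.arange(10).float().unsqueeze(1)   # 10 one-feature samples; a tensor works as a map-style dataset

# 1) pass batch_size and let DataLoader build a BatchSampler(sampler, batch_size, drop_last) internally
loader_a = DataLoader(data, batch_size=4, shuffle=False)
# 2) pass an explicit batch_sampler; batch_size/shuffle/sampler/drop_last must then stay at their defaults
loader_b = DataLoader(data, batch_sampler=BatchSampler(SequentialSampler(data), batch_size=4, drop_last=False))

assert [b.tolist() for b in loader_a] == [b.tolist() for b in loader_b]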
The collate_fn is left for the next post.