maskrcnn_benchmark source code walkthrough: the dataloader (Part 2 of 3)

The previous part covered the routine pieces of the dataloader: data augmentation, dataset concatenation and so on. This part moves on to the data samplers.


Picking up where the previous part left off, we start from the second half of data/build.py.

    data_loaders = []
    for dataset in datasets:  # looks like a loop, but for training, datasets has already been merged into one: datasets = [dataset]
        # for testing, however, no concatenation is done and datasets really is a list of several datasets
        '''
        # data/build.py, build_dataset()
        # for testing, return a list of datasets
        if not is_train:
            return datasets

        # for training, concatenate all datasets into a single one
        dataset = datasets[0]
        if len(datasets) > 1:
            dataset = D.ConcatDataset(datasets)  # concatenate all datasets in the list

        return [dataset]
        '''
        sampler = make_data_sampler(dataset, shuffle, is_distributed)  # this sampler and the batch_sampler below are the focus of this post
        batch_sampler = make_batch_data_sampler(
            dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, is_train, batch_stitch=batch_stitch
        )
        if not is_train:
            #NOTE: original for test
            collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
                BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
            #NOTE: using self-defined BatchCollatorSynthesize() for train
        else:
            collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
                BatchCollatorSynthesize(cfg.DATALOADER.SIZE_DIVISIBILITY)
        num_workers = cfg.DATALOADER.NUM_WORKERS
        if isinstance(batch_sampler, list):
            data_loader = [torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=batch_sampler[0],  # the first batch_sampler
                collate_fn=collator,
            )]

            data_loader.append(torch.utils.data.DataLoader(  # the second of the two data_loaders
                dataset,
                num_workers=num_workers,
                batch_sampler=batch_sampler[1],  # the second batch_sampler
                collate_fn=collator))
        else:
            data_loader = torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                collate_fn=collator,
            )
        data_loaders.append(data_loader)  # the outer list is only needed with multiple datasets; for training, data_loader itself is already a list holding two data loaders
    if is_train:
        # during training, a single (possibly concatenated) data_loader is returned
        assert len(data_loaders) == 1
        return data_loaders[0]
    return data_loaders

From the code above you can see that two data_loaders are built with two different batch_samplers and packed into a list, which then serves as the data loader for a single dataset. So the next thing to look at is these two batch_samplers.
The code above also reveals a containment relation:
sampler -> batch_sampler -> data_loader
data_loader = dataset + batch_sampler + collate_fn + num_workers
Here the dataset provides the data, the batch_sampler decides how samples are drawn from the dataset, and the collate_fn decides how multiple images are merged into one batch, as sketched below.
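
A minimal conceptual sketch of how these three pieces interact (this is not PyTorch's actual DataLoader implementation, which also handles worker processes, pinned memory and so on; the function name is made up for illustration):

def naive_data_loader(dataset, batch_sampler, collate_fn):
    # the batch_sampler yields lists of indices, e.g. [3, 9, 1, 4]
    for batch_indices in batch_sampler:
        # the dataset turns each index into a single sample
        samples = [dataset[i] for i in batch_indices]
        # collate_fn merges the list of samples into one batch
        yield collate_fn(samples)
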
(1)sampler = make_data_sampler(dataset, shuffle, is_distributed)

def make_data_sampler(dataset, shuffle, distributed):
    if distributed:
        return samplers.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler

The sampler defines the strategy for drawing individual samples (indices) from the dataset:
torch.utils.data.sampler.RandomSampler: samples in random order
torch.utils.data.sampler.SequentialSampler: samples in sequential order
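
A quick toy example of what these two samplers yield (the RandomSampler output naturally changes from run to run):

from torch.utils.data.sampler import RandomSampler, SequentialSampler

data = list(range(5))                  # anything with __len__ works as the data source
print(list(SequentialSampler(data)))   # [0, 1, 2, 3, 4]
print(list(RandomSampler(data)))       # e.g. [3, 0, 4, 1, 2], a random permutation
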
(2)batch_sampler = make_batch_data_sampler(
dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, is_train, batch_stitch=batch_stitch
)

def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0, is_train=True, batch_stitch=False,
):
    if aspect_grouping:  # group images by aspect ratio
        if not isinstance(aspect_grouping, (list, tuple)):
            aspect_grouping = [aspect_grouping]
        aspect_ratios = _compute_aspect_ratios(dataset)  # list of aspect ratios of all images
        group_ids = _quantize(aspect_ratios, aspect_grouping)  # quantize the aspect ratios into a list of group ids, e.g. [0, 1, 1, 0, 1, 0, 0, 0, 1]
        if not is_train:
            # note that samplers here comes from "from . import samplers"
            batch_sampler = samplers.GroupedBatchSampler(  # wrap the sampler into a batch sampler
                sampler, group_ids, images_per_batch, drop_uneven=False
            )
        else:
            #NOTE: for slicing training
            batch_sampler = samplers.GroupedBatchSampler(
                sampler, group_ids, images_per_batch, drop_uneven=True
            )
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=False
        )
    if num_iters is not None:
        if is_train:
            if batch_stitch:
                return samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter, False)
            return [samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter, is_regular_epoch) for is_regular_epoch in [True, False]]
        batch_sampler = samplers.IterationBasedBatchSampler(
            batch_sampler, num_iters, start_iter
        )
    return batch_sampler
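
The helpers _compute_aspect_ratios and _quantize are not shown in this series; for reference, upstream maskrcnn_benchmark implements them in data/build.py roughly as follows (this fork may differ slightly):

import bisect


def _quantize(x, bins):
    # map each aspect ratio to the index of the bin it falls into;
    # e.g. with aspect_grouping = [1], ratios below 1 get group id 0 and the rest get 1
    bins = sorted(bins)
    return list(map(lambda y: bisect.bisect_right(bins, y), x))


def _compute_aspect_ratios(dataset):
    # height / width of every image, read from the dataset's image info
    aspect_ratios = []
    for i in range(len(dataset)):
        img_info = dataset.get_img_info(i)
        aspect_ratios.append(float(img_info["height"]) / float(img_info["width"]))
    return aspect_ratios

samplers.IterationBasedBatchSampler, also not shown here, simply re-iterates the wrapped batch_sampler until num_iterations batches have been produced. The upstream version looks roughly like the sketch below; this fork's version additionally accepts the fourth boolean argument seen above, which presumably ends up triggering the GroupedBatchSampler's __call__ (shown later) so that regular and stitched epochs use different batch sizes:

from torch.utils.data.sampler import BatchSampler


class IterationBasedBatchSampler(BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # DistributedSampler needs set_epoch so that each process
            # sees a different shuffle of the dataset every epoch
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
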

As you can see, this function turns the sampler into a batch_sampler.
On the relation between sampler, batch_sampler, collate_fn, dataset and dataloader:
here the sampler draws one sample (index) at a time, and the batch_sampler packs the indices drawn by the sampler into batches and returns them.
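
A toy illustration of this relation with PyTorch's built-in BatchSampler:

from torch.utils.data.sampler import SequentialSampler, BatchSampler

sampler = SequentialSampler(range(10))   # yields 0, 1, ..., 9 one index at a time
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
print(list(batch_sampler))               # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]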

samplers.GroupedBatchSampler is a re-implemented batch sampler:

class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that elements from the same group should appear in groups of batch_size.
    It also tries to provide mini-batches which follows an ordering which is
    as close as possible to the ordering from the original sampler.

    Arguments:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_uneven (bool): If ``True``, the sampler will drop the batches whose
            size is less than ``batch_size``

    """

    def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = torch.as_tensor(group_ids)  # the group id of every sample in the dataset
        assert self.group_ids.dim() == 1
        self.batch_size = batch_size
        #NOTE: synthesized need batch_size
        self.syn_need_batch_size = int(batch_size * cfg.STITCHER.NUM_IMAGES_STITCH)#NUM_IMAGES_STITCH=4
        # the batch size needed when synthesizing (stitching) images
        self.drop_uneven = drop_uneven

        self.groups = torch.unique(self.group_ids).sort(0)[0]#self.groups=tensor([0, 1])

        self._can_reuse_batches = False

    def _prepare_batches(self):
        dataset_size = len(self.group_ids)
        # get the sampled indices from the sampler
        sampled_ids = torch.as_tensor(list(self.sampler))  # the index of every sampled element, e.g. [3, 9, 1, 4, 10]
        # self.sampler is an iterable, so list(self.sampler) calls its __iter__ and collects all sampled indices into a list
        # potentially not all elements of the dataset were sampled
        # by the sampler (e.g., DistributedSampler).
        # construct a tensor which contains -1 if the element was
        # not sampled, and a non-negative number indicating the
        # order where the element was sampled.
        # for example. if sampled_ids = [3, 1] and dataset_size = 5,
        # the order is [-1, 1, -1, 0, -1]
        order = torch.full((dataset_size,), -1, dtype=torch.int64)  # a tensor of dataset_size entries, all -1
        order[sampled_ids] = torch.arange(len(sampled_ids))  # for each sampled index, record the position at which it was sampled
        # e.g. with dataset_size = 11 the all -1 tensor becomes [-1, 2, -1, 0, 3, -1, -1, -1, -1, 1, 4],
        # because sampled_ids = [3, 9, 1, 4, 10] maps to torch.arange(len(sampled_ids)) = [0, 1, 2, 3, 4]
        # get a mask with the elements that were sampled
        mask = order >= 0#(bool mask)

        # find the elements that belong to each individual cluster
        clusters = [(self.group_ids == i) & mask for i in self.groups]  # one boolean mask per group id
        # get relative order of the elements inside each cluster
        # that follows the order from the sampler
        relative_order = [order[cluster] for cluster in clusters]
        # with the relative order, find the absolute order in the
        # sampled space
        permutation_ids = [s[s.sort()[1]] for s in relative_order]
        # permute each cluster so that they follow the order from
        # the sampler
        permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]

        # splits each cluster in batch_size, and merge as a list of tensors
        #splits = [c.split(self.batch_size) for c in permuted_clusters]
        splits = [c.split(self.syn_need_batch_size) for c in permuted_clusters]
        merged = tuple(itertools.chain.from_iterable(splits))

        # now each batch internally has the right order, but
        # they are grouped by clusters. Find the permutation between
        # different batches that brings them as close as possible to
        # the order that we have in the sampler. For that, we will consider the
        # ordering as coming from the first element of each batch, and sort
        # correspondingly
        first_element_of_batch = [t[0].item() for t in merged]
        # get and inverse mapping from sampled indices and the position where
        # they occur (as returned by the sampler)
        inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
        # from the first element in each batch, get a relative ordering
        first_index_of_batch = torch.as_tensor(
            [inv_sampled_ids_map[s] for s in first_element_of_batch]
        )

        # permute the batches so that they approximately follow the order
        # from the sampler
        permutation_order = first_index_of_batch.sort(0)[1].tolist()
        # finally, permute the batches
        batches = [merged[i].tolist() for i in permutation_order]

        if self.drop_uneven:
            kept = []
            for batch in batches:
                #if len(batch) == self.batch_size:
                if len(batch) == self.syn_need_batch_size:
                    kept.append(batch)
            batches = kept
        return batches

    def __iter__(self):
        if self._can_reuse_batches:
            batches = self._batches
            self._can_reuse_batches = False
        else:
            batches = self._prepare_batches()
        self._batches = batches
        return iter(batches)

    def __len__(self):
        if not hasattr(self, "_batches"):
            self._batches = self._prepare_batches()
            self._can_reuse_batches = True
        return len(self._batches)

    #NOTE: slice training need
    def __call__(self, is_regular_epoch):  # this is the key part!
        self.syn_need_batch_size = self.batch_size
        if not is_regular_epoch:
            self.syn_need_batch_size = int(self.syn_need_batch_size * cfg.STITCHER.NUM_IMAGES_STITCH)  # multiply by NUM_IMAGES_STITCH (4) again
        return self

Everything before __call__ is standard GroupedBatchSampler behaviour; the key part is __call__ at the end. is_regular_epoch decides whether a regular batch or a Stitcher batch is used:
in the Stitcher case, syn_need_batch_size becomes 4x batch_size. Note that __call__ first resets self.syn_need_batch_size = self.batch_size, so the value assigned in __init__ is not the one that ends up being used.
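
A hypothetical usage sketch of __call__ (assuming images_per_gpu = 2 and cfg.STITCHER.NUM_IMAGES_STITCH = 4, with sampler and group_ids built as shown earlier; note that __call__ mutates and returns the same object rather than producing a copy):

gbs = samplers.GroupedBatchSampler(sampler, group_ids, batch_size=2)
gbs(True)    # regular epoch:  syn_need_batch_size = 2, two normal-sized images per batch
gbs(False)   # stitcher epoch: syn_need_batch_size = 2 * 4 = 8 downscaled images per batch
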
Question:
why does __init__ set self.syn_need_batch_size = int(batch_size * cfg.STITCHER.NUM_IMAGES_STITCH) at all? This self.syn_need_batch_size is only used at __iter__ time (inside _prepare_batches).

Having read through the batch sampler code, it turns out that all the batch sampler really does is change the batch size according to is_regular_epoch.

The actual image stitching happens in the collate_fn, i.e. the collator; the sampler merely switches between the two batch sizes that drive it.
The batch_size argument of DataLoader exists precisely to be handed over to the batch_sampler.
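
In other words, the two DataLoader constructions below behave (roughly) the same; passing batch_size and shuffle just makes DataLoader build the sampler and batch_sampler internally:

from torch.utils.data import DataLoader, RandomSampler, BatchSampler

loader_a = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=False)

loader_b = DataLoader(
    dataset,
    batch_sampler=BatchSampler(RandomSampler(dataset), batch_size=2, drop_last=False),
)
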
The collate_fn itself is left for the next part of this series.
