单目标跟踪数据集sampler

以STARK中的sampler.py为例:

一、代码整体结构

代码整体结构如下,分为以下阶段:

# 1. 任选一个数据集,如Lasot; 
# 2. 从该数据集中任选一个视频序列,如airplane-5; 
# 3. 从该视频序列中任选一个base frame; 
# 4. 分别从[base_frame_id - max_gap, base_frame_id]和(base_frame_id, base_frame_id +             max_gap]中采样一系列train frames和test frames, 需要注意的是只有visible的帧才会被采样;
# 5. 如果没有找到满足4中的这些帧,逐渐增加max_gap直到这些帧都找到;
# 6. 找到这些帧后对其进行预处理。
class TrackingSampler(torch.utils.data.Dataset):
    # 采样帧来形成一个一个batch。
    # 采样过程分为以下几个阶段:
    # 1. 任选一个数据集,如Lasot; 
    # 2. 从该数据集中任选一个视频序列,如airplane-5; 
    # 3. 从该视频序列中任选一个base frame; 
    # 4. 分别从[base_frame_id - max_gap, base_frame_id]和(base_frame_id, base_frame_id +             max_gap]中采样一系列train frames和test frames, 需要注意的是只有visible的帧才会被采样;
    # 5. 如果没有找到满足4中的这些帧,逐渐增加max_gap直到这些帧都找到;
    # 6. 找到这些帧后对其进行预处理。

    def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
                 num_search_frames, num_template_frames=1, processing=no_processing, frame_sample_mode='causal',
                 train_cls=False, pos_prob=0.5):

    def __len__(self):

    def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None,
                            allow_invisible=False, force_invisible=False):
        # 在min_id和max_id之间采样num_ids个visible的帧。


    def __getitem__(self, index):
        if self.train_cls:
            return self.getitem_cls() # 增加时序信息时使用这个,即score head的训练,用于更新模板
        else:
            return self.getitem() # 无时序信息时用这个

    def getitem(self):
        # 下一模块详细阐述

    def getitem_cls(self):
        # 下一模块详细阐述

    def get_center_box(self, H, W, ratio=1/8):

    def sample_seq_from_dataset(self, dataset, is_video_dataset):
        # 采样一个有足够visible frames的视频序列

    def get_one_search(self):

    def get_frame_ids_trident(self, visible):

    def get_frame_ids_stark(self, visible, valid):

二、 def getitem_cls(self)

sampler.py中getitem_cls(self)函数的解释如下:

# 1. 任选一个数据集(以概率p_datasets选取)
# 2. 从该数据集中采样有足够visible frames的视频序列
# 3. 在该视频序列中确定template_frame_ids和search_frame_ids(要求,visible且间距不超过max_gap)
# 4. 根据template_frame_ids获取templates的图像和bbox,valid,visible,class等信息
# 5. 获取正样本/负样本的search用于分类。正样本直接将第3步得到的search_frame_ids作为search,此时label=1; 负样本则重新任选数据集、视频序列、visible frame作为search,此时label=0。
# 6. 数据增强
def getitem_cls(self):
    # 获取用于分类的数据。在STARK中即用于score head的训练

    valid = False
    label = None
    while not valid:
        # 1. 任选一个数据集(以概率p_datasets选取)
        dataset = random.choices(self.datasets, self.p_datasets)[0] # lasot
        is_video_dataset = dataset.is_video_sequence() # Ture

        # 2. 从该数据集中采样有足够visible frames的视频序列
        seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset) 
        # seq_id=4, visible全为1的tensor,seq_info_dict字典是该序列所有帧的gt信息{'bbox','valid','visible'}
        
        # 3. 在该视频序列中确定template_frame_ids和search_frame_ids(要求,visible且间距不超过max_gap)
        if is_video_dataset:
            if self.frame_sample_mode in ["trident", "trident_pro"]:
                # 这里self.frame_sample_mode是"trident_pro"
                template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible) 
                # 采样得到间距不超过max_gap的三个visible frames,如template_frame_ids=[4159,4078],search_frame_ids=[4250]

            elif self.frame_sample_mode == "stark":
                template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
            else:
                raise ValueError("illegal frame sample mode")
        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            template_frame_ids = [1] * self.num_template_frames
            search_frame_ids = [1] * self.num_search_frames


        # 根据template_frame_ids和search_frame_ids找到其images和bbox
        try:
            # "try" is used to handle trackingnet data failure

             # 4. 根据template_frame_ids获取templates的图像和bbox,valid,visible,class等信息
            template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids,
                                                                         seq_info_dict)
           

            H, W, _ = template_frames[0].shape
            template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros((H, W))] * self.num_template_frames # 因为lasot中没有mask注释,所以template_masks全为0

            # 5. 获取正样本/负样本的search用于分类。正样本直接将第3步得到的search_frame_ids作为search,此时label=1; 负样本则重新任选数据集、视频序列、visible frame作为search,此时label=0。

            # positive samples
            if random.random() < self.pos_prob:
                label = torch.ones(1,)
                search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
                search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
                    (H, W))] * self.num_search_frames


            # negative samples
            else:
                label = torch.zeros(1,)
                if is_video_dataset:
                    search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
                    if search_frame_ids is None:
                        # 任选一个数据集、视频序列、任选一个visible frame作为search
                        search_frames, search_anno, meta_obj_test = self.get_one_search() 
                    else:
                        search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
                                                                                       seq_info_dict)
                        search_anno["bbox"] = [self.get_center_box(H, W)]
                else:
                    search_frames, search_anno, meta_obj_test = self.get_one_search()
                H, W, _ = search_frames[0].shape
                search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
                    (H, W))] * self.num_search_frames

            data = TensorDict({'template_images': template_frames,
                               'template_anno': template_anno['bbox'],
                               'template_masks': template_masks,
                               'search_images': search_frames,
                               'search_anno': search_anno['bbox'],
                               'search_masks': search_masks,
                               'dataset': dataset.get_name(),
                               'test_class': meta_obj_test.get('object_class_name')})

            # 数据增强
            data = self.processing(data)
            # add classification label
            data["label"] = label # 0为负样本
            # check whether data is valid
            valid = data['valid']
        except:
            valid = False

    return data

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值