ATOM代码调试

一、运行环境

参考PyTracking代码中的INSTALL.md或者INSTALL_win.md进行环境配置,后者用于在windows环境下配置运行环境,不建议在windows下配置

一般情况下能顺利配置,Linux下需要一些root权限,同时源代码还需要加入pr pooling的模块,加入方法在源代码给出的文件中有说明

源代码地址地址

Pr-Pooling地址地址

二、数据集

ATOM的训练需要LaSOTGOT-10KTrackingNetCOCO数据集,需要提前准备

三、训练流程—数据集

./ltr/run_training.py进入

直接跳到./ltr/train_settings/bbreg/atom.py开始

首先是对训练流程进行配置

settings.description = 'ATOM IoUNet with default settings, but additionally using GOT10k for training.'
settings.batch_size = 64
settings.num_workers = 8
settings.print_interval = 1
settings.normalize_mean = [0.485, 0.456, 0.406]
settings.normalize_std = [0.229, 0.224, 0.225]
settings.search_area_factor = 5.0
settings.feature_sz = 18
settings.output_sz = settings.feature_sz * 16
settings.center_jitter_factor = {
   'train': 0, 'test': 4.5}
settings.scale_jitter_factor = {
   'train': 0, 'test': 0.5}

然后是训练及验证数据的初始化

# train
lasot_train = Lasot(settings.env.lasot_dir, split='train')
got10k_train = Got10k(settings.env.got10k_dir, split='vottrain')
trackingnet_train = TrackingNet(settings.env.trackingnet_dir, set_ids=list(range(4)))
coco_train = MSCOCOSeq(settings.env.coco_dir)
# val
got10k_val = Got10k(settings.env.got10k_dir, split='votval')

主要目的是为了得到各个数据集的信息,均基于类BaseVideoDataset,而BaseVideoDataset继承于Dataset

该部分其实是为5.15.2的数据采样器准备的

3.1 LaSot

# ./ltr/dataset/lasot.py
class Lasot(BaseVideoDataset):
    def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
    def _build_sequence_list(self, vid_ids=None, split=None):
    def _build_class_list(self):
    def get_name(self):
        # 返回数据集名称
    def has_class_info(self):
        # return true
    def has_occlusion_info(self):
        # return true
    def get_num_sequences(self):
        # 认为效果等同于__len__,返回训练集视频数量
    def get_num_classes(self):
        # 返回数据集class的数量
    def get_sequences_in_class(self, class_name):
        # 返回属于该类的视频序列列表
    def _read_bb_anno(self, seq_path):
        # 获取gt box
    def _read_target_visible(self, seq_path):
        # 获取含有目标的视频帧
    def _get_sequence_path(self, seq_id):
        # 获取序列的地址
    def get_sequence_info(self, seq_id):
        # 会返回一个字典{'bbox': bbox, 'valid': valid, 'visible': visible},元素好像是张量,大概意思应该是返回有效的视频帧
    def _get_frame_path(self, seq_path, frame_id):
    def _get_frame(self, seq_path, frame_id):
    def _get_class(self, seq_path):
    def get_class_name(self, seq_id):
    def get_frames(self, seq_id, frame_ids, anno=None):
        '''
        按照输入的视频id以及要获取的帧id列表,由该函数返回对应的frames
        输出如下:
        	frame_list: 帧所组成的列表
        	anno_frames: 包含了bounding box的一些注释信息
        	object_meta: 这个视频序列的一些信息,一个字典
        '''

初始化信息

# LaSot数据集包含了70个类,1120个视频(每个类16个视频)
# self.class_list中包含了所有类的类名
# self.class_to_id则是以字典的形式将所有的类用一个唯一的id(数字)表示
# self.sequence_list中包含了所有的视频序列名称
# self.seq_per_class将所有的类所属的序列进行了归纳 (此处可以利用,需先了解序列获取机理)

3.2 GOT-10k

# ./ltr/dataset/got10k.py
class Got10k(BaseVideoDataset):
    def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
    def get_name(self):
    def has_class_info(self):
    def has_occlusion_info(self):
    def _load_meta_info(self):
    def _read_meta(self, seq_path):
    def _build_seq_per_class(self):
    def get_sequences_in_class(self, class_name):
    def _get_sequence_list(self):
    def _read_bb_anno(self, seq_path):
    def _read_target_visible(self, seq_path):
    def _get_sequence_path(self, seq_id):
    def get_sequence_info(self, seq_id):
    def _get_frame_path(self, seq_path, frame_id):
    def _get_frame(self, seq_path, frame_id):
    def get_class_name(self, seq_id):
    def get_frames(self, seq_id, frame_ids, anno=None):

初始化信息

# GOT-10k vottrain包含460个目标类别
# self.sequence_list将会包含全部序列名称 (9335个视频,经过split之后会变为7086个,但是每个类别的视频数量并不一致,最少为1,最多为1744)
# self.sequence_meta_info会以字典形式包含视频序列的一些属性 [object_class_name, motion_class, major_class, root_class, motion_adverb]
# self.seq_per_class将所有的类所属的序列进行了归纳
# self.class_list包含了所有类的类名

3.3 TrackingNet

# ./ltr/dataset.tracking_net.py
class TrackingNet(BaseVideoDataset):
    def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
    def _load_class_info(self):
    def get_name(self):
    def has_class_info(self):
    def get_sequences_in_class(self, class_name):
    def _read_bb_anno(self, seq_id):
    def get_sequence_info(self, seq_id):
    def _get_frame(self, seq_id, frame_id):
    def _get_class(self, seq_id):
    def get_class_name(self, seq_id):
    def get_frames(self, seq_id, frame_ids, anno=None):

初始化信息

# TrackingNet(0-3) 包含有21个目标种类,10044个视频序列,也存在分布不均衡,但是没有GOT-10K这么极端,视频基数较大
# self.sequence_list包含有所有的视频序列信息,以元组形式 (数据集编号,序列名称)
# self.seq_to_class_map是一个视频序列和对应的类别之间的一个映射表
# self.seq_per_class将所有的类所属的序列进行了归纳
# self.class_list包含了所有类的类名

3.4 COCO

# ./ltr/dataset/coco_seq.py
class MSCOCOSeq(BaseVideoDataset):
    def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2017"):
    def _get_sequence_list(self):
    def is_video_sequence(self):
    def get_num_classes(self):
    def get_name(self):
    def has_class_info(self):
    def get_class_list(self):
    def has_segmentation_info(self):
    def get_num_sequences(self):
    def _build_seq_per_class(self):
    def get_sequences_in_class(self, class_name):
    def get_sequence_info(self, seq_id):
    def _get_anno(self, seq_id):
    def _get_frames(self, seq_id):
    def get_meta_info(self, seq_id):
    def get_class_name(self, seq_id):
    def get_frames(self, seq_id=None, frame_ids=None, anno=None):

初始化信息

# COCO数据集包含80个类
# self.cats表明了每一个图像的对应信息
# self.class_list包含了所有类的名称
# self.sequence_list包含了所有图像id
# self.seq_per_class将所有的类所属的序列(单张图像)进行了归纳

四、训练流程—数据集处理

4.1 Transform模块

代码作者使用的Transfrom模块是由自己重写的

这个Transform功能相较于原本torchvision中的版本,不仅能对图像进行统一的Transform,还能对边界框(bbox)以及分割掩码(mask)进行同步的Transform,这在之前的torchvision.transforms中是无法做到的

通过设置输入的参数,可以使图像连同边界框以及掩码一起变换

# Phase 1
transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05))

上为atom.py中使用的一个Transform的代码

其中包含两个部分,首先是tfm.ToCrayscale,类形式如下

# ./ltr/data/transforms.py
# 注释中说明该类的作用是按照概率将图像转变为灰度图
class ToGrayscale(TransformBase):
    def __init__
  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值