一、运行环境
参考PyTracking代码中的INSTALL.md或者INSTALL_win.md进行环境配置,后者用于在windows环境下配置运行环境,不建议在windows下配置
一般情况下能顺利配置,Linux下需要一些root权限,同时源代码还需要加入pr pooling的模块,加入方法在源代码给出的文件中有说明
源代码地址:地址
Pr-Pooling地址:地址
二、数据集
ATOM的训练需要LaSOT、GOT-10K、TrackingNet、COCO数据集,需要提前准备
三、训练流程—数据集
由./ltr/run_training.py进入
直接跳到./ltr/train_settings/bbreg/atom.py开始
首先是对训练流程进行配置
settings.description = 'ATOM IoUNet with default settings, but additionally using GOT10k for training.'
settings.batch_size = 64
settings.num_workers = 8
settings.print_interval = 1
settings.normalize_mean = [0.485, 0.456, 0.406]
settings.normalize_std = [0.229, 0.224, 0.225]
settings.search_area_factor = 5.0
settings.feature_sz = 18
settings.output_sz = settings.feature_sz * 16
settings.center_jitter_factor = {
'train': 0, 'test': 4.5}
settings.scale_jitter_factor = {
'train': 0, 'test': 0.5}
然后是训练及验证数据的初始化
# train
lasot_train = Lasot(settings.env.lasot_dir, split='train')
got10k_train = Got10k(settings.env.got10k_dir, split='vottrain')
trackingnet_train = TrackingNet(settings.env.trackingnet_dir, set_ids=list(range(4)))
coco_train = MSCOCOSeq(settings.env.coco_dir)
# val
got10k_val = Got10k(settings.env.got10k_dir, split='votval')
主要目的是为了得到各个数据集的信息,均基于类BaseVideoDataset,而BaseVideoDataset继承于Dataset
该部分其实是为5.1和5.2的数据采样器准备的
3.1 LaSot
# ./ltr/dataset/lasot.py
class Lasot(BaseVideoDataset):
def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
def _build_sequence_list(self, vid_ids=None, split=None):
def _build_class_list(self):
def get_name(self):
# 返回数据集名称
def has_class_info(self):
# return true
def has_occlusion_info(self):
# return true
def get_num_sequences(self):
# 认为效果等同于__len__,返回训练集视频数量
def get_num_classes(self):
# 返回数据集class的数量
def get_sequences_in_class(self, class_name):
# 返回属于该类的视频序列列表
def _read_bb_anno(self, seq_path):
# 获取gt box
def _read_target_visible(self, seq_path):
# 获取含有目标的视频帧
def _get_sequence_path(self, seq_id):
# 获取序列的地址
def get_sequence_info(self, seq_id):
# 会返回一个字典{'bbox': bbox, 'valid': valid, 'visible': visible},元素好像是张量,大概意思应该是返回有效的视频帧
def _get_frame_path(self, seq_path, frame_id):
def _get_frame(self, seq_path, frame_id):
def _get_class(self, seq_path):
def get_class_name(self, seq_id):
def get_frames(self, seq_id, frame_ids, anno=None):
'''
按照输入的视频id以及要获取的帧id列表,由该函数返回对应的frames
输出如下:
frame_list: 帧所组成的列表
anno_frames: 包含了bounding box的一些注释信息
object_meta: 这个视频序列的一些信息,一个字典
'''
初始化信息
# LaSot数据集包含了70个类,1120个视频(每个类16个视频)
# self.class_list中包含了所有类的类名
# self.class_to_id则是以字典的形式将所有的类用一个唯一的id(数字)表示
# self.sequence_list中包含了所有的视频序列名称
# self.seq_per_class将所有的类所属的序列进行了归纳 (此处可以利用,需先了解序列获取机理)
3.2 GOT-10k
# ./ltr/dataset/got10k.py
class Got10k(BaseVideoDataset):
def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
def get_name(self):
def has_class_info(self):
def has_occlusion_info(self):
def _load_meta_info(self):
def _read_meta(self, seq_path):
def _build_seq_per_class(self):
def get_sequences_in_class(self, class_name):
def _get_sequence_list(self):
def _read_bb_anno(self, seq_path):
def _read_target_visible(self, seq_path):
def _get_sequence_path(self, seq_id):
def get_sequence_info(self, seq_id):
def _get_frame_path(self, seq_path, frame_id):
def _get_frame(self, seq_path, frame_id):
def get_class_name(self, seq_id):
def get_frames(self, seq_id, frame_ids, anno=None):
初始化信息
# GOT-10k vottrain包含460个目标类别
# self.sequence_list将会包含全部序列名称 (9335个视频,经过split之后会变为7086个,但是每个类别的视频数量并不一致,最少为1,最多为1744)
# self.sequence_meta_info会以字典形式包含视频序列的一些属性 [object_class_name, motion_class, major_class, root_class, motion_adverb]
# self.seq_per_class将所有的类所属的序列进行了归纳
# self.class_list包含了所有类的类名
3.3 TrackingNet
# ./ltr/dataset.tracking_net.py
class TrackingNet(BaseVideoDataset):
def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
def _load_class_info(self):
def get_name(self):
def has_class_info(self):
def get_sequences_in_class(self, class_name):
def _read_bb_anno(self, seq_id):
def get_sequence_info(self, seq_id):
def _get_frame(self, seq_id, frame_id):
def _get_class(self, seq_id):
def get_class_name(self, seq_id):
def get_frames(self, seq_id, frame_ids, anno=None):
初始化信息
# TrackingNet(0-3) 包含有21个目标种类,10044个视频序列,也存在分布不均衡,但是没有GOT-10K这么极端,视频基数较大
# self.sequence_list包含有所有的视频序列信息,以元组形式 (数据集编号,序列名称)
# self.seq_to_class_map是一个视频序列和对应的类别之间的一个映射表
# self.seq_per_class将所有的类所属的序列进行了归纳
# self.class_list包含了所有类的类名
3.4 COCO
# ./ltr/dataset/coco_seq.py
class MSCOCOSeq(BaseVideoDataset):
def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2017"):
def _get_sequence_list(self):
def is_video_sequence(self):
def get_num_classes(self):
def get_name(self):
def has_class_info(self):
def get_class_list(self):
def has_segmentation_info(self):
def get_num_sequences(self):
def _build_seq_per_class(self):
def get_sequences_in_class(self, class_name):
def get_sequence_info(self, seq_id):
def _get_anno(self, seq_id):
def _get_frames(self, seq_id):
def get_meta_info(self, seq_id):
def get_class_name(self, seq_id):
def get_frames(self, seq_id=None, frame_ids=None, anno=None):
初始化信息
# COCO数据集包含80个类
# self.cats表明了每一个图像的对应信息
# self.class_list包含了所有类的名称
# self.sequence_list包含了所有图像id
# self.seq_per_class将所有的类所属的序列(单张图像)进行了归纳
四、训练流程—数据集处理
4.1 Transform模块
代码作者使用的Transfrom模块是由自己重写的
这个Transform功能相较于原本torchvision中的版本,不仅能对图像进行统一的Transform,还能对边界框(bbox)以及分割掩码(mask)进行同步的Transform,这在之前的torchvision.transforms中是无法做到的
通过设置输入的参数,可以使图像连同边界框以及掩码一起变换
# Phase 1
transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05))
上为atom.py中使用的一个Transform的代码
其中包含两个部分,首先是tfm.ToCrayscale,类形式如下
# ./ltr/data/transforms.py
# 注释中说明该类的作用是按照概率将图像转变为灰度图
class ToGrayscale(TransformBase):
def __init__