trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
=======================================================================================================================
epoch = -1
num_tries = 1
for i in range(num_tries):
try:
if load_latest:
self.load_checkpoint()
for epoch in range(self.epoch+1, max_epochs+1):
self.epoch = epoch
self.train_epoch()
self.cycle_dataset(loader)  # loader is an LTRLoader (len 1875; see line 315 of the trainer)
-----------------------------------------------------------------------------------------------------------------------
for i, data in enumerate(loader, 1):
# data: 'template_images' has shape (2, 32, 3, 128, 128) — (num template images, batch_size, C, H, W);
# 'template_anno' has shape (2, 32, 4)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
while not valid:
# Select a dataset
dataset = random.choices(self.datasets, self.p_datasets)[0]#Lasot:1120
# sample a sequence from the given dataset
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
# Example seq_info_dict for a sequence with 630 frames:
# {'bbox': tensor([[482., 278., 23., 14.], ..., [596., 295., 27., 17.]]),
#  'valid': tensor([True, True, True, ..., True, True, True]),
#  'visible': tensor([1, 1, 1, ..., 1, 1, 1], dtype=torch.uint8)}
# sample template and search frame ids
if is_video_dataset:
if self.frame_sample_mode in ["trident", "trident_pro"]:
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
elif self.frame_sample_mode == "stark":
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
else:
raise ValueError("illegal frame sample mode")
else:
# In case of image dataset, just repeat the image to generate synthetic video
template_frame_ids = [1] * self.num_template_frames
search_frame_ids = [1] * self.num_search_frames
try:
# "try" is used to handle trackingnet data failure
# get images and bounding boxes (for templates)
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids,
seq_info_dict)
H, W, _ = template_frames[0].shape
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros(
(H, W))] * self.num_template_frames
# get images and bounding boxes (for searches)
# positive samples
if random.random() < self.pos_prob:
label = torch.ones(1,)
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
(H, W))] * self.num_search_frames
# negative samples
else:
label = torch.zeros(1,)
if is_video_dataset:
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
if search_frame_ids is None:
search_frames, search_anno, meta_obj_test = self.get_one_search()
else:
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
seq_info_dict)
search_anno["bbox"] = [self.get_center_box(H, W)]
else:
search_frames, search_anno, meta_obj_test = self.get_one_search()
H, W, _ = search_frames[0].shape
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
(H, W))] * self.num_search_frames
data = TensorDict({
'template_images': template_frames,      # shape: (2, 2, 3, 128, 128)
'template_anno': template_anno['bbox'],  # shape: (2, 2, 4)
'template_masks': template_masks,        # shape: (2, 2, 128, ...) — truncated in the original note; presumably (2, 2, 128, 128)
Analysis of the MixFormerOnlineScore training code
First published 2022-07-18 20:54:34