1.mmMOT系统框架图
2.我理解的框架图
3.代码结构
4.代码解读
按自上而下的顺序解读:从 main.py 开始,逐步深入到各个模块
main.py
程序主文件,用来训练和验证模型
-
读取配置文件,初始化追踪模型
# Parse the command-line arguments and load the 'common' section of the
# YAML config file named by args.config.
global args, config, best_mota
args = parser.parse_args()
with open(args.config) as f:
    # NOTE(review): yaml.FullLoader is fine for a trusted local config;
    # prefer yaml.safe_load if the file could come from an untrusted source.
    config = yaml.load(f, Loader=yaml.FullLoader)
config = EasyDict(config['common'])
# save_path is the directory that contains the config file.
config.save_path = os.path.dirname(args.config)

# Build the model, optimizer and loss from the config, then wrap them in
# the TrackingModule used for training and prediction.
model = build_model(config)
optimizer = build_optim(model, config)
criterion = build_criterion(config.loss)
tracking_module = TrackingModule(model, optimizer, criterion,
                                 config.det_type)
-
数据加载(dataset文件夹)
# Data loading code
train_transform, valid_transform = build_augmentation(config.augmentation)

# train
# Training split with the training transform (evaluate=False).
train_dataset = build_dataset(
    config,
    set_source='train',
    evaluate=False,
    train_transform=train_transform)
# Training split again, but in evaluate mode with the validation transform.
trainval_dataset = build_dataset(
    config,
    set_source='train',
    evaluate=True,
    valid_transform=valid_transform)
# Validation split in evaluate mode.
val_dataset = build_dataset(
    config,
    set_source='val',
    evaluate=True,
    valid_transform=valid_transform)

# The sampler is configured for a single process (world_size=1, rank=0).
# `last_iter` is defined outside this excerpt.
train_sampler = DistributedGivenIterationSampler(
    train_dataset,
    config.lr_scheduler.max_iter,
    config.batch_size,
    world_size=1,
    rank=0,
    last_iter=last_iter)

# shuffle stays False because the ordering is delegated to train_sampler.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.workers,
    pin_memory=True,
    sampler=train_sampler)
-
训练
def train(train_loader, val_loader, trainval_loader, tracking_module, lr_scheduler, start_iter, tb_logger):
# 前向传播 # forward loss = tracking_module.step(input.squeeze(0), det_info, det_id, det_cls, det_split)
-
验证
def validate(val_loader, tracking_module, step, part='train', fusion_list=None, fuse_prob=False):
tracking_model.py
追踪模型主文件,包含模式切换、训练过程和预测过程
-
模式转换
训练模式or评估模式
# Evaluation mode
def eval(self):
    """Switch the wrapped model(s) to evaluation mode.

    ``self.model`` may be a single model or a list of models; ``eval()`` is
    called on each of them, and ``self.clear_mem()`` is called afterwards.
    Returns ``None``.
    """
    models = self.model if isinstance(self.model, list) else [self.model]
    for net in models:
        net.eval()
    self.clear_mem()


# Training mode
def train(self):
    """Switch the wrapped model(s) to training mode.

    ``self.model`` may be a single model or a list of models; ``train()`` is
    called on each of them, and ``self.clear_mem()`` is called afterwards.
    Returns ``None``.
    """
    models = self.model if isinstance(self.model, list) else [self.model]
    for net in models:
        net.train()
    self.clear_mem()
-
step(),训练函数,用于模型训练,得到4种score后计算损失并反向传播
def step(self, det_img, det_info, det_id, det_cls, det_split):
    """Run one training iteration and return the loss.

    The model is run on the inputs, ground-truth targets are generated from
    the first detection score, the criterion computes the loss, and one
    optimizer update is applied.

    :param det_img: detection inputs passed to the model
    :param det_info: additional detection information passed to the model
    :param det_id: detection ids, used to generate the ground truth
    :param det_cls: detection classes, used to generate the ground truth
    :param det_split: per-frame split of the detections
    :return: the loss object returned by ``self.criterion``
    """
    # Forward pass: the full tracking network returns the four kinds of
    # score (detection, link, new, end) plus `trans`.
    det_score, link_score, new_score, end_score, trans = self.model(
        det_img, det_info, det_split)

    # generate gt_y
    gt_det, gt_link, gt_new, gt_end = self.generate_gt(
        det_score[0], det_cls, det_id, det_split)

    # calculate loss
    # (per these notes, the criterion is the loss defined in cost.py)
    loss = self.criterion(det_split, gt_det, gt_link, gt_new, gt_end,
                          det_score, link_score, new_score, end_score, trans)

    # Back-propagate and apply the optimizer update.
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    return loss
-
predict():得到追踪结果。step() 在算出 score 之后直接计算损失并反向传播,不需要求出最终的追踪结果;predict() 则根据得到的 4 种 score,使用线性规划方法求解最终的关联结果
def predict(self, det_imgs, det_info, dets, det_split):
    """Run the model and solve the association to get tracking results.

    :param det_imgs: detection inputs passed to the model
    :param det_info: additional detection information passed to the model
    :param dets: detections, used when assigning ids to boxes
    :param det_split: per-frame split of the detections
    :return: ``(aligned_ids, aligned_dets, frame_start)`` as produced by
        ``self.align_id``
    """
    # Forward pass: the full tracking network returns the four kinds of
    # score; the fifth output is not used for prediction.
    det_score, link_score, new_score, end_score, _ = self.model(
        det_imgs, det_info, det_split)

    # ortools_solve is the linear-programming solver (solvers.py per these
    # notes); it associates detections from the four scores selected by
    # self.test_mode.
    # NOTE (author's opinion): this step could presumably also be done with
    # greedy matching, the Hungarian algorithm, or a neural network.
    assign_det, assign_link, assign_new, assign_end = ortools_solve(
        det_score[self.test_mode],
        [link_score[0][self.test_mode:self.test_mode + 1]],
        new_score[self.test_mode],
        end_score[self.test_mode],
        det_split)

    # Turn the solver's assignments into per-detection ids and boxes.
    assign_id, assign_bbox = self.assign_det_id(assign_det, assign_link,
                                                assign_new, assign_end,
                                                det_split, dets)
    aligned_ids, aligned_dets, frame_start = self.align_id(
        assign_id, assign_bbox)

    return aligned_ids, aligned_dets, frame_start
solvers.py
线性规划模块,以4种估计器估计得到的score作为输入,得到最终的关联结果
# 函数头,输入为估计器得到的估计结果,约束条件是当前帧的所有识别结果要么是新轨迹的开始,要么和之前的轨迹相连,或者要么是旧轨迹的结束,要么和之后的轨迹相连
def ortools_solve(det_score,
link_score,
new_score,
end_score,
det_split,
gt=None):
......
# 返回值为关联结果
return assign_det, assign_link, assign_new, assign_end
cost.py
损失函数定义模块,损失函数计算公式为:
-
CostLoss类、NoDistanceLoss类、DistanceLoss类(我认为这几个类可能是用于做消融实验的);
-
DetLoss类,用于计算detection score、new score和end score的损失,可以使用binary cross entropy损失、l1损失、l2损失和ghm损失,其中ghm损失在module/ghm_loss中定义;
-
LinkLoss类,用于计算关联损失,可以使用l1损失或l2损失。
-
TrackingLoss类,用于计算完整框架的loss,用于模型训练时的反向传播。
# LinkLoss类,两类损失函数初始化
if 'l2' in loss_type:
self.l2_loss = nn.MSELoss()
if 'l1' in loss_type:
print("Use smooth l1 loss for link")
self.l1_loss = nn.SmoothL1Loss()
# 计算LinkLoss
def forward(self, det_split, gt_det, link_score, gt_link):
if 'l2' in self.loss_type:
loss += self.l2_loss(link_score[i].mul(mask),gt_link[i].repeat(mask.size(0), 1, 1))
if 'l1' in self.loss_type:
loss += self.l1_loss(link_score[i].mul(mask),gt_link[i].repeat(mask.size(0), 1, 1))
return loss
# DetLoss类,用于计算det loss、new loss、end loss,可以使用4种损失函数
def forward(self, det_score, gt_score):
"""
:param det_score: 3xL
:param gt_score: L
:return: loss
"""
gt_score = gt_score.unsqueeze(0).repeat(det_score.size(0), 1)
if 'bce' in self.loss_type:
loss = F.binary_cross_entropy_with_logits(det_score, gt_score)
if 'l2' in self.loss_type:
mask = 1 - gt_score.eq(self.ignore_index)
loss = F.mse_loss(det_score.mul(mask.float()), gt_score)
if 'l1' in self.loss_type:
mask = 1 - gt_score.eq(self.ignore_index)
loss = F