mmMOT代码阅读笔记

1.mmMOT系统框架图2.我理解的框架图3.代码结构4.代码解读自上往下的顺序,从main文件开始,逐渐详细main.py程序主文件,用来训练和验证模型获取配置文件,追踪模型初始化global args, config, best_mota args = parser.parse_args() with open(args.config) as f: config = yaml.load(f, Loader=yaml.FullLoader)
摘要由CSDN通过智能技术生成

1.mmMOT系统框架图

在这里插入图片描述
在这里插入图片描述

2.我理解的框架图

在这里插入图片描述

3.代码结构

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fFm6DtRB-1600938557756)(pics/image-20200923223033058.png)]

4.代码解读

自上往下的顺序,从main文件开始,逐渐详细

main.py

程序主文件,用来训练和验证模型

  1. 获取配置文件,追踪模型初始化

    global args, config, best_mota
        args = parser.parse_args()
    
        with open(args.config) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    
        config = EasyDict(config['common'])
        config.save_path = os.path.dirname(args.config)
    
    model = build_model(config)
    optimizer = build_optim(model, config)
    criterion = build_criterion(config.loss)
    
    tracking_module = TrackingModule(model, optimizer, criterion,config.det_type)
    
  2. 数据加载(dataset文件夹)

    # Data loading code
        train_transform, valid_transform = build_augmentation(config.augmentation)
    
        # train
        train_dataset = build_dataset(
            config,
            set_source='train',
            evaluate=False,
            train_transform=train_transform)
        trainval_dataset = build_dataset(
            config,
            set_source='train',
            evaluate=True,
            valid_transform=valid_transform)
        val_dataset = build_dataset(
            config,
            set_source='val',
            evaluate=True,
            valid_transform=valid_transform)
    
        train_sampler = DistributedGivenIterationSampler(
            train_dataset,
            config.lr_scheduler.max_iter,
            config.batch_size,
            world_size=1,
            rank=0,
            last_iter=last_iter)
    
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.workers,
            pin_memory=True,
            sampler=train_sampler)
    
  3. 训练

    def train(train_loader, val_loader, trainval_loader, tracking_module, lr_scheduler, start_iter, tb_logger):
    
    # 前向传播
    # forward
    loss = tracking_module.step(input.squeeze(0), det_info, det_id, det_cls, det_split)
    
  4. 验证

    def validate(val_loader,
                 tracking_module,
                 step,
                 part='train',
                 fusion_list=None,
                 fuse_prob=False):
    
tracking_model.py

追踪模型主文件,模式转换、训练过程、预测过程

  1. 模式转换

    训练模式or评估模式

    #评估模式
    def eval(self):
        if isinstance(self.model, list):
            for i in range(len(self.model)):
              self.model[i].eval()
        else:
            self.model.eval()
        self.clear_mem()
        return
    
    #训练模式
    def train(self):
        if isinstance(self.model, list):
            for i in range(len(self.model)):
                self.model[i].train()
        else:
            self.model.train()
        self.clear_mem()
        return
    
  2. step(),训练函数,用于模型训练,得到4种score后计算损失并反向传播

    def step(self, det_img, det_info, det_id, det_cls, det_split):
        # model是tracking_net搭建起来的完整的训练模型,经过模型计算得到4种score
        det_score, link_score, new_score, end_score, trans = self.model(
            det_img, det_info, det_split)
        # generate gt_y
        gt_det, gt_link, gt_new, gt_end = self.generate_gt(
            det_score[0], det_cls, det_id, det_split)
    
        # calculate loss
        # loss值,由cost.py模块计算得到
        loss = self.criterion(det_split, gt_det, gt_link, gt_new, gt_end,
                              det_score, link_score, new_score, end_score,
                              trans)
    
        # 反向传播并优化
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
    
        return loss
    
  3. predict(),得到追踪结果,step()计算出来score之后直接反向传播,不必得到最终的追踪结果,predict根据得到的4种分数使用线性规划方法求解最终的结果

    def predict(self, det_imgs, det_info, dets, det_split):
        # model是tracking_net搭建起来的完整的训练模型,经过模型计算得到4种score
        det_score, link_score, new_score, end_score, _ = self.model(
            det_imgs, det_info, det_split)
    
        # ortools_solve是线性规划模块,在solvers.py模块中,用于根据4种score进行关联,得到追踪结果
        # 我觉得这一部分的内容也可以用贪婪匹配、匈牙利算法或者神经网络替代
        assign_det, assign_link, assign_new, assign_end = ortools_solve(
            det_score[self.test_mode],
            [link_score[0][self.test_mode:self.test_mode + 1]],
            new_score[self.test_mode], end_score[self.test_mode], det_split)
    
        assign_id, assign_bbox = self.assign_det_id(assign_det, assign_link,
                                                        assign_new, assign_end,
                                                        det_split, dets)
        aligned_ids, aligned_dets, frame_start = self.align_id(
            assign_id, assign_bbox)
    
        return aligned_ids, aligned_dets, frame_start
    
solvers.py

线性规划模块,以4种估计器估计得到的score作为输入,得到最终的关联结果

# 函数头,输入为估计器得到的估计结果,约束条件是当前帧的所有识别结果要么是新轨迹的开始,要么和之前的轨迹相连,或者要么是旧轨迹的结束,要么和之后的轨迹相连
def ortools_solve(det_score,
                  link_score,
                  new_score,
                  end_score,
                  det_split,
                  gt=None):
    ......

# 返回值为关联结果
	return assign_det, assign_link, assign_new, assign_end
cost.py

损失函数定义模块,损失函数计算公式为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WrNQt3hB-1600938557757)(pics/mmMOT Cost function.png)]

  • CostLoss类、NoDistanceLoss类、DistanceLoss类(我认为这几个类可能是用于做消融实验的);

  • DetLoss类,用于计算detection score和new score和end score的损失,可以使用binary cross entropy损失、l1损失、l2损失和ghm损失,其中ghm损失在module/ghm_loss中定义;

  • LinkLoss类,用于计算关联损失,可以使用l1损失或l2损失。

  • TrackingLoss类,用于计算完整框架的loss,用于模型训练时的反向传播。

# LinkLoss类,两类损失函数初始化
if 'l2' in loss_type:
    self.l2_loss = nn.MSELoss()
if 'l1' in loss_type:
    print("Use smooth l1 loss for link")
    self.l1_loss = nn.SmoothL1Loss()
    
# 计算LinkLoss
def forward(self, det_split, gt_det, link_score, gt_link):
    if 'l2' in self.loss_type:
        loss += self.l2_loss(link_score[i].mul(mask),gt_link[i].repeat(mask.size(0), 1, 1))
    if 'l1' in self.loss_type:
        loss += self.l1_loss(link_score[i].mul(mask),gt_link[i].repeat(mask.size(0), 1, 1))
    return loss
# DetLoss类,用于计算det loss、new loss、end loss,可以使用4种损失函数
def forward(self, det_score, gt_score):
    """

    :param det_score: 3xL
    :param gt_score: L
    :return: loss
    """
    gt_score = gt_score.unsqueeze(0).repeat(det_score.size(0), 1)
    if 'bce' in self.loss_type:
        loss = F.binary_cross_entropy_with_logits(det_score, gt_score)
    if 'l2' in self.loss_type:
        mask = 1 - gt_score.eq(self.ignore_index)
        loss = F.mse_loss(det_score.mul(mask.float()), gt_score)
    if 'l1' in self.loss_type:
        mask = 1 - gt_score.eq(self.ignore_index)
        loss = F.smoo
  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值