EPNet代码理解

分享一下我学习 EPNet 代码的理解。先看 tools 文件夹下的 train 文件:文件前面是配置参数、划分数据 batch、定义优化器和损失函数等代码,这里先跳过,直接看主函数:

if __name__ == "__main__":
    # Training entry point: merge the config, set up output/log/backup
    # directories, build the data loaders, model, optimizer and schedulers,
    # then run train_utils.Trainer.
    import shlex  # local import: quote paths interpolated into shell commands

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    print(cfg.TRAIN.RPN_TRAIN_WEIGHT, cfg.TRAIN.RCNN_TRAIN_WEIGHT)

    # The tag names the output directory. cfg_file may be None (see the check
    # above), and os.path.basename(None) raises TypeError, so only derive the
    # tag from the file name when one was given; otherwise keep any TAG
    # already present on cfg.
    if args.cfg_file is not None:
        cfg.TAG = os.path.splitext(os.path.basename(args.cfg_file))[0]
    else:
        cfg.TAG = getattr(cfg, 'TAG', 'default')

    # train_mode selects which stages are enabled/trainable and where results go.
    if args.train_mode == 'rpn':
        cfg.RPN.ENABLED = True
        cfg.RCNN.ENABLED = False
        root_result_dir = os.path.join('../', 'output', 'rpn', cfg.TAG)
    elif args.train_mode == 'rcnn':
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = cfg.RPN.FIXED = True
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
    elif args.train_mode == 'rcnn_online':
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = True
        cfg.RPN.FIXED = False
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
    elif args.train_mode == 'rcnn_offline':
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = False
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
    else:
        raise NotImplementedError

    if args.output_dir is not None:
        root_result_dir = args.output_dir
    os.makedirs(root_result_dir, exist_ok=True)

    log_file = os.path.join(root_result_dir, 'log_train.txt')
    logger = create_logger(log_file)
    logger.info('**********************Start logging**********************')

    # log to file
    gpu_list = os.environ.get('CUDA_VISIBLE_DEVICES', 'ALL')
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))

    save_config_to_file(cfg, logger=logger)

    # copy important files to backup
    backup_dir = os.path.join(root_result_dir, 'backup_files')
    os.makedirs(backup_dir, exist_ok=True)
    quoted_backup_dir = shlex.quote(backup_dir)
    os.system('cp *.py %s/' % quoted_backup_dir)
    # Directories need -r: without it cp prints "omitting directory" and
    # copies nothing.
    os.system('cp -r ../lib/ %s/' % quoted_backup_dir)
    os.system('cp -r ../tools %s/' % quoted_backup_dir)
    os.system('cp ../*.py %s/' % quoted_backup_dir)

    # tensorboard log
    print(root_result_dir)
    tb_log = SummaryWriter(logdir=os.path.join(root_result_dir, 'tensorboard'))

    # create dataloader & network & optimizer
    train_loader, test_loader = create_dataloader(logger)
    # The same decorated model function is used for training and evaluation.
    fn_decorator = train_functions.model_joint_fn_decorator()

    model = PointRCNN(num_classes=train_loader.dataset.num_class, use_xyz=True, mode='TRAIN')

    optimizer = create_optimizer(model)

    if args.mgpus:
        model = nn.DataParallel(model)
    model.cuda()

    # Unwrap DataParallel once; the checkpoint loaders below take the bare module.
    pure_model = model.module if isinstance(model, torch.nn.DataParallel) else model

    # load checkpoint if it is possible
    start_epoch = it = 0
    last_epoch = -1
    if args.ckpt is not None:
        it, start_epoch = train_utils.load_checkpoint(pure_model, optimizer, filename=args.ckpt, logger=logger)
        last_epoch = start_epoch + 1

    lr_scheduler, bnm_scheduler = create_scheduler(optimizer, total_steps=len(train_loader) * args.epochs,
                                                   last_epoch=last_epoch)

    if args.rpn_ckpt is not None:
        total_keys = len(pure_model.state_dict())
        train_utils.load_part_ckpt(pure_model, filename=args.rpn_ckpt, logger=logger, total_keys=total_keys)

    # The one-cycle optimizer schedules its own warmup, so cosine warmup is
    # only added for the other optimizers.
    if cfg.TRAIN.LR_WARMUP and cfg.TRAIN.OPTIMIZER != 'adam_onecycle':
        lr_warmup_scheduler = train_utils.CosineWarmupLR(optimizer, T_max=cfg.TRAIN.WARMUP_EPOCH * len(train_loader),
                                                         eta_min=cfg.TRAIN.WARMUP_MIN)
    else:
        lr_warmup_scheduler = None

    # start training
    logger.info('**********************Start training**********************')
    ckpt_dir = os.path.join(root_result_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    trainer = train_utils.Trainer(
            model,
            fn_decorator,
            optimizer,
            ckpt_dir=ckpt_dir,
            lr_scheduler=lr_scheduler,
            bnm_scheduler=bnm_scheduler,
            model_fn_eval=fn_decorator,
            tb_log=tb_log,
            eval_frequency=1,
            lr_warmup_scheduler=lr_warmup_scheduler,
            warmup_epoch=cfg.TRAIN.WARMUP_EPOCH,
            grad_norm_clip=cfg.TRAIN.GRAD_NORM_CLIP
    )

    trainer.train(
            it,
            start_epoch,
            args.epochs,
            train_loader,
            test_loader,
            ckpt_save_interval=args.ckpt_save_interval,
            lr_scheduler_each_iter=(cfg.TRAIN.OPTIMIZER == 'adam_onecycle')
    )

    logger.info('**********************End training**********************')

上面主函数中的大部分代码与配置解析、日志、结果输出和 checkpoint 有关,可以先不管,重点看构建模型的这一行:

 model = PointRCNN(num_classes = train_loader.dataset.num_class, use_xyz = True, mode = 'TRAIN')

这一行构建了 EPNet 网络。作者是在 PointRCNN 的基础上修改的,所以类名仍然是 `PointRCNN`。

进入 `PointRCNN` 类所在的文件看一下:

class PointRCNN(nn.Module):
    """Two-stage detector: an RPN stage followed by an RCNN refinement stage.

    Which stages exist is decided by the global ``cfg``:
    ``cfg.RPN.ENABLED`` builds ``self.rpn`` (features and proposals) and
    ``cfg.RCNN.ENABLED`` builds ``self.rcnn_net`` (proposal refinement).
    At least one of them must be enabled.

    Args:
        num_classes: number of classes passed to the RCNN head.
        use_xyz: forwarded to both ``RPN`` and ``RCNNNet``.
        mode: forwarded to ``RPN`` (e.g. ``'TRAIN'``).

    Raises:
        NotImplementedError: if ``cfg.RCNN.BACKBONE`` is not ``'pointnet'``.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features
            if cfg.RCNN.BACKBONE == 'pointnet':
                self.rcnn_net = RCNNNet(num_classes=num_classes, input_channels=rcnn_input_channels,
                                        use_xyz=use_xyz)
            elif cfg.RCNN.BACKBONE == 'pointsift':
                # Not implemented. Leaving self.rcnn_net unset would make every
                # forward() with RCNN enabled fail later with an AttributeError,
                # so fail fast at construction time instead.
                raise NotImplementedError("RCNN backbone 'pointsift' is not implemented")
            else:
                raise NotImplementedError

    def forward(self, input_data):
        """Run the enabled stages on ``input_data``.

        * RPN only: returns the RPN output dict.
        * RPN + RCNN: returns the RPN outputs plus ``'rois'``,
          ``'roi_scores_raw'`` and ``'seg_result'`` computed from them, merged
          with the RCNN output dict.
        * RCNN only: ``input_data`` is passed straight to ``self.rcnn_net``
          and its output is returned.

        Args:
            input_data: dict of network inputs; when training with the RCNN
                behind the RPN it must contain ``'gt_boxes3d'``.

        Returns:
            dict of outputs as described above.
        """
        if cfg.RPN.ENABLED:
            output = {}
            # rpn inference. Gradients flow through the RPN only when it is
            # trainable (not FIXED) and the module is in training mode.
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    # A fixed RPN is always run in eval mode.
                    self.rpn.eval()
                rpn_output = self.rpn(input_data)

                output.update(rpn_output)
                backbone_xyz = rpn_output['backbone_xyz']
                backbone_features = rpn_output['backbone_features']

            # rcnn inference
            if cfg.RCNN.ENABLED:
                # Every tensor computed inside this context has
                # requires_grad=False, so proposal generation is not differentiated.
                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']

                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    # 0/1 float mask of points whose sigmoid score exceeds the threshold.
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    # Per-point depth: L2 norm of the coordinates along dim 2.
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # proposal layer
                    rois, roi_scores_raw = self.rpn.proposal_layer(rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)

                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                rcnn_input_info = {
                    'rpn_xyz': backbone_xyz,
                    # Swap the last two axes of the backbone features for the RCNN head.
                    'rpn_features': backbone_features.permute((0, 2, 1)),
                    'seg_mask': seg_mask,
                    'roi_boxes3d': rois,
                    'pts_depth': pts_depth,
                }
                if self.training:
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)

        elif cfg.RCNN.ENABLED:
            # RCNN without an RPN in the graph (e.g. train_mode 'rcnn_offline').
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError

        return output
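送入第二阶段 `rcnn_net` 的 `rcnn_input_info` 是一个字典,包含:骨干网络输出的点坐标 (`rpn_xyz`)、点特征 (`rpn_features`)、由分类得分阈值化得到的前景分割掩码 (`seg_mask`)、候选框 (`roi_boxes3d`)、点深度 (`pts_depth`),训练时还包含真实框 (`gt_boxes3d`)。`forward` 最终返回的 `output` 合并了 RPN 的输出、候选框结果和 RCNN 的输出,也就是第一阶段 + 第二阶段的整体网络。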

其中,`self.rpn`(`RPN`)是第一阶段提取特征和生成候选框的网络。

`self.rcnn_net`(`RCNNNet`)则是第二阶段细化候选框的网络。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值