分享一下我学习 EPNet 代码的一些理解。先看 tools 文件夹下的 train 训练脚本:前面部分主要是配置参数解析、数据划分为 batch、定义优化器和损失函数等准备工作,这里先略过,直接看主函数:
if __name__ == "__main__":
    # Script entry point: configure the run, set up logging/backup, build the
    # data loaders, model, optimizer and schedulers, then hand off to Trainer.
    import shlex  # local import: used only to shell-quote the backup path

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    print(cfg.TRAIN.RPN_TRAIN_WEIGHT, cfg.TRAIN.RCNN_TRAIN_WEIGHT)

    # Name the experiment after the config file (cfgs/foo.yaml -> "foo").
    # Guarded because os.path.basename(None) raises TypeError; without a
    # config file the TAG already present in cfg is kept.
    if args.cfg_file is not None:
        cfg.TAG = os.path.splitext(os.path.basename(args.cfg_file))[0]

    # Select which sub-networks are built/trained and where results go.
    if args.train_mode == 'rpn':
        cfg.RPN.ENABLED = True
        cfg.RCNN.ENABLED = False
        root_result_dir = os.path.join('../', 'output', 'rpn', cfg.TAG)
    elif args.train_mode == 'rcnn':
        # RCNN on top of an RPN marked FIXED (not trained).
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = cfg.RPN.FIXED = True
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
    elif args.train_mode == 'rcnn_online':
        # RPN and RCNN both enabled, RPN trainable.
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = True
        cfg.RPN.FIXED = False
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
    elif args.train_mode == 'rcnn_offline':
        # RCNN only; the RPN is not built at all.
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = False
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
    else:
        raise NotImplementedError

    if args.output_dir is not None:
        root_result_dir = args.output_dir
    os.makedirs(root_result_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Logging
    # ------------------------------------------------------------------
    log_file = os.path.join(root_result_dir, 'log_train.txt')
    logger = create_logger(log_file)
    logger.info('**********************Start logging**********************')

    # log to file
    gpu_list = os.environ.get('CUDA_VISIBLE_DEVICES', 'ALL')
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)
    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    save_config_to_file(cfg, logger=logger)

    # ------------------------------------------------------------------
    # Back up source files next to the results for reproducibility.
    # Directories need `cp -r`: a plain `cp` refuses to copy a directory,
    # so the ../lib and ../tools backups previously did nothing.
    # NOTE(review): `cp -r ../lib/` also copies any build artifacts under
    # lib; narrow the pattern if the backup gets too large.
    # ------------------------------------------------------------------
    backup_dir = os.path.join(root_result_dir, 'backup_files')
    os.makedirs(backup_dir, exist_ok=True)
    quoted_backup_dir = shlex.quote(backup_dir)  # survive spaces/shell chars in output_dir
    os.system('cp *.py %s/' % quoted_backup_dir)
    os.system('cp -r ../lib/ %s/' % quoted_backup_dir)
    os.system('cp -r ../tools %s/' % quoted_backup_dir)
    os.system('cp ../*.py %s/' % quoted_backup_dir)

    # tensorboard log
    print(root_result_dir)
    tb_log = SummaryWriter(logdir=os.path.join(root_result_dir, 'tensorboard'))

    # ------------------------------------------------------------------
    # Data loaders, network, optimizer
    # ------------------------------------------------------------------
    train_loader, test_loader = create_dataloader(logger)
    # One decorated model function, shared by training and evaluation below.
    fn_decorator = train_functions.model_joint_fn_decorator()
    model = PointRCNN(num_classes=train_loader.dataset.num_class, use_xyz=True, mode='TRAIN')
    optimizer = create_optimizer(model)

    if args.mgpus:
        model = nn.DataParallel(model)
    model.cuda()

    # ------------------------------------------------------------------
    # Resume from checkpoint if one is given
    # ------------------------------------------------------------------
    start_epoch = it = 0
    last_epoch = -1
    if args.ckpt is not None:
        # Operate on the underlying module when wrapped in DataParallel.
        pure_model = model.module if isinstance(model, torch.nn.DataParallel) else model
        it, start_epoch = train_utils.load_checkpoint(pure_model, optimizer, filename=args.ckpt, logger=logger)
        last_epoch = start_epoch + 1

    lr_scheduler, bnm_scheduler = create_scheduler(optimizer, total_steps=len(train_loader) * args.epochs,
                                                   last_epoch=last_epoch)

    # Optionally initialise (part of) the network from a separate RPN checkpoint.
    if args.rpn_ckpt is not None:
        pure_model = model.module if isinstance(model, torch.nn.DataParallel) else model
        total_keys = len(pure_model.state_dict().keys())
        train_utils.load_part_ckpt(pure_model, filename=args.rpn_ckpt, logger=logger, total_keys=total_keys)

    # Cosine LR warm-up, except for the one-cycle optimizer.
    if cfg.TRAIN.LR_WARMUP and cfg.TRAIN.OPTIMIZER != 'adam_onecycle':
        lr_warmup_scheduler = train_utils.CosineWarmupLR(optimizer, T_max=cfg.TRAIN.WARMUP_EPOCH * len(train_loader),
                                                         eta_min=cfg.TRAIN.WARMUP_MIN)
    else:
        lr_warmup_scheduler = None

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    logger.info('**********************Start training**********************')
    ckpt_dir = os.path.join(root_result_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    trainer = train_utils.Trainer(
        model,
        fn_decorator,
        optimizer,
        ckpt_dir=ckpt_dir,
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
        model_fn_eval=fn_decorator,
        tb_log=tb_log,
        eval_frequency=1,
        lr_warmup_scheduler=lr_warmup_scheduler,
        warmup_epoch=cfg.TRAIN.WARMUP_EPOCH,
        grad_norm_clip=cfg.TRAIN.GRAD_NORM_CLIP,
    )

    trainer.train(
        it,
        start_epoch,
        args.epochs,
        train_loader,
        test_loader,
        ckpt_save_interval=args.ckpt_save_interval,
        # Per-iteration scheduler stepping is enabled only for adam_onecycle.
        lr_scheduler_each_iter=(cfg.TRAIN.OPTIMIZER == 'adam_onecycle'),
    )

    logger.info('**********************End training**********************')
上面这段主函数主要负责参数配置、日志与结果输出以及训练流程的组织,暂时可以不用管。重点是其中创建网络模型的这一行:
# Build the two-stage detector (RPN + RCNN) in training mode.
model = PointRCNN(
    num_classes=train_loader.dataset.num_class,
    use_xyz=True,
    mode='TRAIN',
)
这一行调用的就是 EPNet 网络:作者是在 PointRCNN 的基础上修改的,所以类名仍然沿用 PointRCNN。
进入定义这个类的文件看看:
class PointRCNN(nn.Module):
    """Two-stage 3D detector: an RPN stage followed by an RCNN refinement stage.

    Which stages are built is controlled by the global ``cfg``:
    ``cfg.RPN.ENABLED`` builds ``self.rpn`` and ``cfg.RCNN.ENABLED`` builds
    ``self.rcnn_net``. At least one of the two must be enabled.

    Args:
        num_classes: number of classes, passed to ``RCNNNet``.
        use_xyz: forwarded to both ``RPN`` and ``RCNNNet``.
        mode: forwarded to ``RPN`` (e.g. ``'TRAIN'``).

    Raises:
        NotImplementedError: if the RCNN is enabled and ``cfg.RCNN.BACKBONE``
            is anything other than ``'pointnet'``.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        super().__init__()
        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            # Only the PointNet++ head exists. 'pointsift' used to hit a silent
            # `pass`, leaving `self.rcnn_net` undefined so forward() failed
            # later with AttributeError; fail fast here instead.
            if cfg.RCNN.BACKBONE != 'pointnet':
                raise NotImplementedError('Unsupported cfg.RCNN.BACKBONE: %s' % cfg.RCNN.BACKBONE)
            rcnn_input_channels = 128  # channels of rpn features
            self.rcnn_net = RCNNNet(num_classes=num_classes, input_channels=rcnn_input_channels,
                                    use_xyz=use_xyz)

    def forward(self, input_data):
        """Run the enabled stage(s) and return an output dict.

        With the RPN enabled, the dict contains every entry returned by the
        RPN (the code reads ``backbone_xyz``, ``backbone_features``,
        ``rpn_cls`` and ``rpn_reg`` from it). If the RCNN is enabled as well,
        it also contains ``rois``, ``roi_scores_raw``, ``seg_result`` and the
        RCNN outputs. With only the RCNN enabled, ``input_data`` goes straight
        to ``self.rcnn_net`` and its result is returned.

        Raises:
            NotImplementedError: if neither stage is enabled.
        """
        if cfg.RPN.ENABLED:
            output = {}
            # RPN inference: gradients are tracked only when the RPN is
            # trainable (not FIXED) and the module is in training mode.
            with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
                if cfg.RPN.FIXED:
                    # Keep a frozen RPN in eval mode even while training RCNN.
                    self.rpn.eval()
                rpn_output = self.rpn(input_data)
                output.update(rpn_output)
                backbone_xyz = rpn_output['backbone_xyz']
                backbone_features = rpn_output['backbone_features']

            # RCNN inference
            if cfg.RCNN.ENABLED:
                # Proposal generation is not trained through: tensors computed
                # inside no_grad() have requires_grad=False.
                with torch.no_grad():
                    rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']
                    rpn_scores_raw = rpn_cls[:, :, 0]
                    rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
                    # Binary foreground mask from thresholded per-point scores.
                    seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
                    # L2 norm of each point's coordinates along dim 2.
                    pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

                    # proposal layer
                    rois, roi_scores_raw = self.rpn.proposal_layer(rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)

                    output['rois'] = rois
                    output['roi_scores_raw'] = roi_scores_raw
                    output['seg_result'] = seg_mask

                # Built and run OUTSIDE no_grad() so the RCNN head gets gradients.
                # Keys: rpn_xyz, rpn_features, seg_mask, roi_boxes3d, pts_depth,
                # plus gt_boxes3d in training mode.
                rcnn_input_info = {
                    'rpn_xyz': backbone_xyz,
                    'rpn_features': backbone_features.permute((0, 2, 1)),
                    'seg_mask': seg_mask,
                    'roi_boxes3d': rois,
                    'pts_depth': pts_depth,
                }
                if self.training:
                    rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

                rcnn_output = self.rcnn_net(rcnn_input_info)
                output.update(rcnn_output)
        elif cfg.RCNN.ENABLED:
            # RCNN-only mode: presumably input_data already carries the RCNN
            # inputs (e.g. precomputed RPN results) -- verify against the dataset.
            output = self.rcnn_net(input_data)
        else:
            raise NotImplementedError

        return output
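forward 最终返回一个字典 output,它包含:RPN 的全部输出(如 rpn_cls、rpn_reg、backbone_xyz、backbone_features);候选框 rois 及其原始得分 roi_scores_raw;前景分割掩码 seg_result;以及 RCNN 网络的输出。送入 RCNN 的 rcnn_input_info 则包含点坐标 rpn_xyz、点特征 rpn_features、分割掩码 seg_mask、候选框 roi_boxes3d、点深度 pts_depth,训练时还会加上真实框 gt_boxes3d。也就是说,这个类就是由第一阶段(RPN)和第二阶段(RCNN)组成的完整网络。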
其中,`self.rpn = RPN(use_xyz=use_xyz, mode=mode)` 是第一阶段的网络,负责提取点云特征并生成候选框;
`self.rcnn_net = RCNNNet(...)` 则是第二阶段的网络,负责细化候选框。下面来看这部分代码。