论文地址:https://arxiv.org/abs/2012.04355v3
项目地址:https://github.com/THU17cyz/3DIoUMatch-PVRCNN
本篇论文的关键代码:
- Dataset:
KittiDatasetSSL
- Model:
PVRCNN_SSL_3DIOU
(本篇文章介绍)
一、原始PVRCNN
- PVRCNN的所有模块:
MeanVFE, VoxelBackBone8xm, HeightCompressionm, VoxelSetAbstraction, BaseBEVBackbone, AnchorHeadSingle, PointHeadSimple, PVRCNNHead。- 说明:
后三个模块涉及损失计算,因此在半监督实验中,teacher和student操作略有不同。(teacher不会调用损失计算函数)- 说明2:
首先记录一下PVRCNN网络。
1. 前向传播forward
AnchorHeadSingle【检测头网络】:
self.forward_ret_dict
更新:用于损失计算。- Prediction:
- cls_preds: [2, 200, 176, 18]
- box_preds: [2, 200, 176, 42]
- dir_cls_preds: [2, 200, 176, 12]
- Target:
- box_cls_labels: [2, 211200] // 211200=200*176*18/3
- box_reg_targets: [2, 211200, 7]
- reg_weights: [2, 211200]
- Prediction:
data_dict
更新:用于保存预测框,为获取Proposal做准备。- batch_cls_preds: [2, 211200, 3]
- batch_box_preds: [2, 211200, 7]
PointHeadSimple【点云特征】:
self.forward_ret_dict
更新:用于损失计算。- point_cls_preds: [4096, 1]
- point_cls_labels: [4096]
data_dict
更新:用于保存点云特征,为Refinement作准备。- point_cls_scores:[4096]
PVRCNNHead【检测头网络】:
self.forward_ret_dict
更新:用于损失计算。- Prediction:
- rcnn_cls, rcnn_reg
- 其他信息:
- rois, gt_of_rois, gt_iou_of_rois, roi_scores, roi_labels, reg_valid_mask, rcnn_cls_labels, gt_of_rois_src
- Prediction:
data_dict
更新:- rois, roi_scores, roi_labels: 根据batch_cls_preds, batch_box_preds和ground truth获得roi区域,用于计算最终框。(在nolabel数据中,需要剔除从gt获得roi的部分)
- batch_cls_preds: [2, 128, 1]
- batch_box_preds: [2, 128, 7]
2. 损失计算
三个模块的损失和
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict = self.dense_head.get_loss()
loss_point, tb_dict = self.point_head.get_loss(tb_dict)
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss_rpn + loss_point + loss_rcnn
return loss, tb_dict, disp_dict
二、半监督PVRCNN (Training Mode)
1. 模型构建和加载
构建
__init__
中实现。包括两个结构完全一样的模型
- self.pv_rcnn:参数通过后向传播优化。
- self.pv_rcnn_ema:参数从计算图中分离,通过指数滑动平均优化,在
update_global_step
中更新。
加载
load_params_from_file
中实现。两个模块pv_rcnn和pv_rcnn_ema加载相同参数。
2. 主要流程
- 网络输入:
data_dict
包含Labeled/Unlabeled的强/弱数据增强结果。 - Teacher前向传播:网络输入为Unlabeled的弱数据增强数据,得到预测结果
batch_cls_preds: [2, 100, 1],batch_box_preds [2, 100, 7]
- Teacher后处理:根据上面两个预测结果,得到预测框
pred_boxes [22, 7], pred_scores [22], pred_labels [22]
作为pseudo box。 - pseudo box处理:filter筛选,数据增强对齐,作为Unlabeled数据的真值。即
batch_dict['gt_boxes'][1]
更新为pseudo box。 - pseudo box处理2:论文的LHS模型,应该与NMS起到类似作用。未细看。
- Student前向传播
- 损失计算:半监督损失计算有
scalar
参数,默认返回标量损失值(全监督设置),该训练中返回各batch分别的损失值。目的是对label/unlabel损失设置权重。
附: Teacher前向传播forward
disable_gt_roi_when_pseudo_labeling
开关控制Teacher中独特的计算。
Student前向传播与原始PVRCNN一致。
AnchorHeadSingle【检测头网络】:
self.forward_ret_dict
更新:用于损失计算。- Prediction:
- cls_preds: [2, 200, 176, 18]
- box_preds: [2, 200, 176, 42]
- dir_cls_preds: [2, 200, 176, 12]
- Target: 因为无需计算损失,所以这部分未存储。
- Prediction:
data_dict
更新:用于保存预测框,为获取Proposal做准备。- batch_cls_preds: [2, 211200, 3]
- batch_box_preds: [2, 211200, 7]
PointHeadSimple【点云特征】:
self.forward_ret_dict
更新:用于损失计算。- point_cls_preds: [4096, 1]
- point_cls_labels: 因为无需计算损失,所以这部分未存储。
data_dict
更新:用于保存点云特征,为Refinement作准备。- point_cls_scores:[4096]
PVRCNNHead【检测头网络】:
self.forward_ret_dict
更新:用于损失计算。- Prediction:
- rcnn_cls, rcnn_reg
- 其他信息:
- rois, roi_scores, roi_labels 因为无需计算损失,其他信息未存储。
- Prediction:
data_dict
更新:- rois, roi_scores, roi_labels: 根据batch_cls_preds, batch_box_preds获得roi区域,用于计算最终框。(剔除从gt获得roi的部分)
- batch_cls_preds: [2, 128, 1]
- batch_box_preds: [2, 128, 7]