转载自:https://blog.csdn.net/weixin_38362784/article/details/111479263,勿喷。
CenterPoint 在mmdetection3d中的实现
写在前面
# mmdet3d/models/detectors/centerpoint.py
class CenterPoint(MVXTwoStageDetector):
"""Base class of Multi-modality VoxelNet."""
...
def forward_pts_train(self,
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
outs = self.pts_bbox_head(pts_feats)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.pts_bbox_head.loss(*loss_inputs)
return losses
此处的self.pts_bbox_head,在配置文件中设置为CenterHead。
因此,主要分析CenterHead中的forward函数和loss函数。
一、CenterHead的forward函数
【待补充】
二、CenterHead的loss函数
Step0: 参数说明
# mmdet3d/models/dense_heads/centerpoint_head.py
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
Step1: 对真值进行处理
1. 简述:
根据gt_bboxes_3d和gt_labels_3d,生成各task的热图、框尺寸等信息。
2. loss函数中的实现:
# mmdet3d/models/dense_heads/centerpoint_head.py
# loss function
heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)
3. 关键函数:
mmdet3d/models/dense_heads/centerpoint_head.py
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
3-1 参数说明
3-2 中间变量说明
根据配置文件中的大类,一共有6个task
task_masks
按照类别划分,记录各类别目标在gt_bboxes_3d/gt_labels_3d中的坐标ID:
task_boxes
记录各task中的真实框参数。
task_classes
重新排序各task中的真实框类别。(0是背景)
3-3 针对每个Task,生成heatmap、anno_box、ind、mask
遍历该Task内的所有目标,更新上述四个变量。
3-3-1 heatmap的更新
draw_gaussian(heatmap[cls_id], center_int, radius)
参数说明
cls_id 决定在哪一张热图上更新
center_int 记录中心点在热图上的位置(x, y)
radius 决定高斯核大小
结果举例
3-3-2 anno_box的更新
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0),
vx.unsqueeze(0),
vy.unsqueeze(0)
])
- 第1-2维表示中心点的偏移量offset_x offset_y
热图上的坐标(x, y)是离散整型,实际的中心点有精确到小数的偏移。 - 第3维表示中心点的高度z
- 第4-6维表示目标框的长宽高box_dim
- 第7-8维表示旋转角度sin(α) cos(α)
- 第9-10维表示速度vx vy
nuScenes数据集有速度数据,如需使用KITTI数据集,需要更改部分代码。
3-4 返回值说明
heatmaps, anno_boxes, inds, masks 均是长度为6的数组,保存6个task的内容。
至此,“Step1: 对真值进行处理“ 已经完成。
Step2: 损失值计算
1. 简述:
根据preds_dicts和Step1得到的heatmaps, anno_boxes, inds, masks ,分别计算每一个task的loss_heatmap和loss_bbox。
2. preds_dicts说明:
preds_dicts包含6个元素,分别是6个task的预测结果。下表以preds_dicts[0]举例:
3. 计算loss_heatmap:GaussianFocalLoss
loss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'], # 预测得到的热图 [BS, cls_num, 128, 128]
heatmaps[task_id], # 实际的热图 [BS, cls_num, 128, 128]
avg_factor=max(num_pos, 1)) # num_pos表示实际目标的数量
此处self.loss_cls是GaussianFocalLoss,该损失函数的实现见:
mmdetection/mmdet/models/losses/gaussian_focal_loss.py
4. 计算loss_bbox:
loss_bbox = self.loss_bbox(
pred, # 预测 torch.Size([BS, 500, 10])
target_box, # 真值 torch.Size([BS, 500, 10])
bbox_weights, # torch.Size([BS, 500, 10]) 第2维表示mask 第3维表示权重
avg_factor=(num + 1e-4))
此处self.loss_bbox是L1Loss,该损失函数的实现见:
mmdetection/mmdet/models/losses/smooth_l1_loss.py
Step3: 返回值说明【总结】
最终得到的loss_dict举例:
task0.loss_heatmap: 1.0833, task0.loss_bbox: 0.5410,
task1.loss_heatmap: 1.2952, task1.loss_bbox: 0.5907,
task2.loss_heatmap: 1.1385, task2.loss_bbox: 0.5840,
task3.loss_heatmap: 1.0866, task3.loss_bbox: 0.4793,
task4.loss_heatmap: 1.1322, task4.loss_bbox: 0.5697,
task5.loss_heatmap: 1.2827, task5.loss_bbox: 0.6070