CenterPoint 在mmdetection3d中的实现

CenterPoint 在mmdetection3d中的实现

模型以如下配置文件为例:
configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py
MMDetection3d官方模型:
CenterPoint (继承自MVXTwoStageDetector
该博客主要分析关键代码:
CenterHead

写在前面

# mmdet3d/models/detectors/centerpoint.py
class CenterPoint(MVXTwoStageDetector):
    """Base class of Multi-modality VoxelNet."""
    ...
    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          img_metas,
                          gt_bboxes_ignore=None):
       
        outs = self.pts_bbox_head(pts_feats)
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
        losses = self.pts_bbox_head.loss(*loss_inputs)
        return losses

此处的self.pts_bbox_head,在配置文件中设置为CenterHead
因此,主要分析CenterHead中的forward函数和loss函数。

一、CenterHeadforward函数

【待补充】

二、CenterHeadloss函数

Step0: 参数说明

# mmdet3d/models/dense_heads/centerpoint_head.py
 def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
参数gt_bboxes_3dgt_labels_3dpreds_dicts
说明保存真值:框的参数保存真值:框的类别forward函数的输出
类型list[:obj:LiDARInstance3DBoxes]list[torch.Tensor]dict
备注列表长度表示 batch_size列表长度表示 batch_size包含6个元素,分别是6个task的预测结果
举例元素举例:tensor([0, 9, 0 …], device=‘cuda:0’)将在后续详细说明

Step1: 对真值进行处理

1. 简述:

根据gt_bboxes_3dgt_labels_3d,生成各task的热图、框尺寸等信息。

2. loss函数中的实现:
# mmdet3d/models/dense_heads/centerpoint_head.py
# loss function
heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)
3. 关键函数:
# mmdet3d/models/dense_heads/centerpoint_head.py
 def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
3-1 参数说明
参数gt_bboxes_3dgt_labels_3d
说明保存真值:框的参数保存真值:框的类别
类型obj:LiDARInstance3DBoxestorch.Tensor
取值举例经处理后,可得到实际框的参数tensor([0, 9, 0 …], device=‘cuda:0’)
尺寸举例经处理后,torch.Size([76, 9])torch.Size([76])
3-2 中间变量说明

根据配置文件中的大类,一共有6个task

  • task_masks
    按照类别划分,记录各类别目标在gt_bboxes_3d/gt_labels_3d中的坐标ID:
    在这里插入图片描述 在这里插入图片描述
  • task_boxes
    记录各task中的真实框参数。
  • task_classes
    重新排序各task中的真实框类别。(0是背景)
    在这里插入图片描述 在这里插入图片描述
3-3 针对每个Task,生成heatmapanno_boxindmask
参数heatmapanno_boxindmask
说明中心点热图框的参数框的中心点在heatmap中的位置前obj_num个元素置1,obj_num表示框的个数
尺寸[class_num, 128, 128][500, 10][500][500]
取值举例每个class有一张热图10维参数的含义,见下ind[idx] = x*128 + ymask[idx] = 1

遍历该Task内的所有目标,更新上述四个变量。

3-3-1 heatmap的更新
draw_gaussian(heatmap[cls_id], center_int, radius)

参数说明

  • cls_id 决定在哪一张热图上更新
  • center_int 记录中心点在热图上的位置(x, y)
  • radius 决定高斯核大小

结果举例
在这里插入图片描述

3-3-2 anno_box的更新
anno_box[new_idx] = torch.cat([
                        center - torch.tensor([x, y], device=device),
                        z.unsqueeze(0), box_dim,
                        torch.sin(rot).unsqueeze(0),
                        torch.cos(rot).unsqueeze(0),
                        vx.unsqueeze(0),
                        vy.unsqueeze(0)
                    ])
  1. 第1-2维表示中心点的偏移量offset_x offset_y
    热图上的坐标(x, y)是离散整型,实际的中心点有精确到小数的偏移。
  2. 第3维表示中心点的高度z
  3. 第4-6维表示目标框的长宽高box_dim
  4. 第7-8维表示旋转角度sin(α) cos(α)
  5. 第9-10维表示速度vx vy
    nuScenes数据集有速度数据,如需使用KITTI数据集,需要更改部分代码。
3-4 返回值说明

heatmaps, anno_boxes, inds, masks 均是长度为6的数组,保存6个task的内容。

至此,“Step1: 对真值进行处理“ 已经完成。

Step2: 损失值计算

1. 简述:

根据preds_dicts和Step1得到的heatmaps, anno_boxes, inds, masks ,分别计算每一个task的loss_heatmaploss_bbox

2. preds_dicts说明:

在这里插入图片描述
preds_dicts包含6个元素,分别是6个task的预测结果。下表以preds_dicts[0]举例:

preds_dicts[0] KEYdimheatmapheightregrotvel
尺寸[batch_size, 3, 128, 128][batch_size, 1, 128, 128][batch_size, 1, 128, 128][batch_size, 2, 128, 128][batch_size, 2, 128, 128][batch_size, 2, 128, 128]
说明表示目标框的长宽高热图表示中心点的高度表示中心点的偏移量表示旋转角度表示速度
对应Step1得到的真值anno_box第4-6维box_dimheatmapanno_box第3维zanno_box第1-2维offset_x offset_yanno_box第7-8维sin(α) cos(α)anno_box第9-10维vx vy
3. 计算loss_heatmap:GaussianFocalLoss
loss_heatmap = self.loss_cls(
                preds_dict[0]['heatmap'],	# 预测得到的热图 [BS, cls_num, 128, 128]
                heatmaps[task_id],			# 实际的热图 [BS, cls_num, 128, 128]
                avg_factor=max(num_pos, 1))	# num_pos表示实际目标的数量

此处self.loss_cls是GaussianFocalLoss,该损失函数的实现见:
mmdetection/mmdet/models/losses/gaussian_focal_loss.py

4. 计算loss_bbox
loss_bbox = self.loss_bbox(
                pred,					# 预测 torch.Size([BS, 500, 10])
                target_box,				# 真值 torch.Size([BS, 500, 10])
                bbox_weights,			# torch.Size([BS, 500, 10]) 第2维表示mask 第3维表示权重
                avg_factor=(num + 1e-4))

此处self.loss_bbox是L1Loss,该损失函数的实现见:
mmdetection/mmdet/models/losses/smooth_l1_loss.py

Step3: 返回值说明【总结】

最终得到的loss_dict举例:

task0.loss_heatmap: 1.0833, task0.loss_bbox: 0.5410, 
task1.loss_heatmap: 1.2952, task1.loss_bbox: 0.5907, 
task2.loss_heatmap: 1.1385, task2.loss_bbox: 0.5840, 
task3.loss_heatmap: 1.0866, task3.loss_bbox: 0.4793, 
task4.loss_heatmap: 1.1322, task4.loss_bbox: 0.5697, 
task5.loss_heatmap: 1.2827, task5.loss_bbox: 0.6070
  • 8
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 12
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值