1.损失函数在
bbox_head中定义,类型为CenterNetHead
具体相关代码
bbox_head=dict( type='CenterNetHead', num_classes=5, in_channel=64, feat_channel=64, loss_center_heatmap=dict(type='GaussianFocalLoss', loss_weight=1.0), loss_wh=dict(type='L1Loss', loss_weight=0.1), loss_offset=dict(type='L1Loss', loss_weight=1.0)),
具体有关loss_wh的计算,类型为L1Loss,函数定义在model中,losses下的smooth_l1_loss.py文件,具体调用损失函数为L1Loss
本质为
pred和target做差
在计算loss_wh时,只计算center存在点的loss,这一实现依靠weight权重
weight的计算在CenterNetHead类中实现,实现函数为
def get_targets(self, gt_bboxes, gt_labels, feat_shape, img_shape):
核心代码如下:
wh_offset_target_weight = gt_bboxes[-1].new_zeros( [bs, 2, feat_h, feat_w])
wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1
初始化全0的weight的tensor,
只把中心点对应处的weight改为1。