【mmdetection】FCOS
损失函数
下面我们再来看看 FCOSHead 类里面定义的损失函数 loss。
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def loss(self,
cls_scores,
bbox_preds,
centernesses,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
'''
cls_scores: [5][batchsize,80,H_i,W_i]
bbox_preds: [5][batchsize,4,H_i,W_i]
centernesses: [5][batchsize,1,H_i,W_i]
gt_bboxes: [batchsize][num_obj,4]
gt_labels: [batchsize][num_obj]
img_metas: [batchsize][(dict)dict_keys(['filename', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip', 'img_norm_cfg'])]
cfg: {'assigner': {'type': 'MaxIoUAssigner', 'pos_iou_thr': 0.5, 'neg_iou_thr': 0.4, 'min_pos_iou': 0, 'ignore_iof_thr': -1}, 'allowed_border': -1, 'pos_weight': -1, 'debug': False}
'''
assert len(cls_scores) == len(bbox_preds) == len(centernesses) # 5
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] # P3-P7特征图的大小
'''
[torch.Size([100, 152]),
torch.Size([50, 76]),
torch.Size([25, 38]),
torch.Size([13, 19]),
torch.Size([7, 10])]
'''
# 特征图的大小就相当于把原图分为多大的grid,特征图每个像素映射到原图就是该grid的中心点,不同大小的特征图就有不同的grid
# bbox_preds[0].dtype:torch.float32
# all_level_points:(list) [5][n_points][2]
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
gt_labels)
'''
labels:[5][batch_size*level_points_i]
bbox_targets:[5][batch_size*level_points_i,4]
'''
num_imgs = cls_scores[0].size(0)
# flatten cls_scores, bbox_preds and centerness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
for bbox_pred in bbox_preds
]
flatten_centerness = [
centerness.permute(0, 2, 3, 1).reshape(-1)
for centerness in centernesses
]
flatten_cls_scores = torch.cat(flatten_cls_scores) # torch.Size([89600, 80]) 所有图片所有point的5个层的输出
flatten_bbox_preds = torch.cat(flatten_bbox_preds) # torch.Size([89600, 4])
flatten_centerness = torch.cat(flatten_centerness) # torch.Size([89600])
flatten_labels = torch.cat(labels) # torch.Size([89600])
flatten_bbox_targets = torch.cat(b