之前做完比赛过后计划看看mmdetection的源码写点blog,写了两篇过后忙其他事去了,这里就接着把之前没写完的东西补上。
之前写了模型和网络的创建,这里就主要写下训练过程中具体的loss,主要分为以下几部分
- RPN_loss
- bbox_loss
- mask_loss
RPN_loss
rpn_loss
的实现具体定义在mmdet/models/anchor_head/rpn_head.py
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
img_metas,
cfg,
gt_bboxes_ignore=None):
losses = super(RPNHead, self).loss(
cls_scores,
bbox_preds,
gt_bboxes,
None,
img_metas,
cfg,
gt_bboxes_ignore=gt_bboxes_ignore)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
具体的计算方式定义在其父类mmdet/models/anchor_heads/anchor_head.py
,主要是loss
和loss_single
两个函数。
先看loss
函数
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
if cls_reg_targets is