前言:
本文详解的是在maskrcnn_benchmark代码中,RoI层中的边框预测模块的损失函数计算代码。在本文详解的loss.py覆盖了预测边框筛选函数,通过该函数可以排除出原预测边框中不符合标准的边框,重新选择背景边框和目标边框,并使用这个边框构成的新预测边框来计算loss值。其代码详解为:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import (
BalancedPositiveNegativeSampler
)
from maskrcnn_benchmark.modeling.utils import cat
class FastRCNNLossComputation(object):
"""
Computes the loss for Faster R-CNN.
Also supports FPN
"""
def __init__(
self,
proposal_matcher,
fg_bg_sampler,
box_coder,
cls_agnostic_bbox_reg=False
):
"""
Arguments:
proposal_matcher (Matcher)
fg_bg_sampler (BalancedPositiveNegativeSampler)
box_coder (BoxCoder)
"""
self.proposal_matcher = proposal_matcher
self.fg_bg_sampler = fg_bg_sampler
self.box_coder = box_coder
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
# todo 计算出所有预测边框所对应的基准边框(groun truth box),并返回对应的列表
def match_targets_to_proposals(self, proposal, target):
# 计算基准边框与预测边框相互之间的IoU
match_quality_matrix = boxlist_iou(target, proposal)
# 计算各个预测边框对应的基准边框(ground truth box)的索引列表,背景边框为-2,模糊边框为-1
matched_idxs = self.proposal_matcher(match_quality_matrix)
# Fast RCNN only need "labels" field for selecting the targets
# 获得基准边框(gr