前言:
在对RPN预测到的边框进行进一步特征提取后,需要对边框进行预测,得到边框的类别和位置大小信息。这一操作在maskrcnn_benchmark中由roi_box_predictors.py完成,该文件实现了两种预测类:直接进行预测以及先池化再预测。其代码详解如下:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.modeling import registry
from torch import nn
# todo 现将预测边框的特征进行池化,再使用边框预测结构和边框回归结构来预测边框的类别以及边框的坐标偏差值
@registry.ROI_BOX_PREDICTOR.register("FastRCNNPredictor")
class FastRCNNPredictor(nn.Module):
def __init__(self, config, in_channels):
super(FastRCNNPredictor, self).__init__()
# 当输入层的通道为空时报错
assert in_channels is not None
# 输入层的通道数
num_inputs = in_channels
# 得到基准边框的类别数,一般都要加上一类为背景
num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
# 对输入层特征先进行池化
self.avgpool = nn.AdaptiveAvgPool2d(1)
# 创建用于预测边框类别的网络结构:线性链接层,类别数