前言:
在经过RPN层之后,网络会生成多个预测边框(proposal),这时需要对这些边框进行RoI池化,使之成为尺度一致的特征。接下来需要对这些特征进行进一步的特征提取,这就要用到roi_box_feature_extractors.py。roi_box_feature_extractors.py定义了三种不同的特征提取方式:ResNet卷积层方式、MLP全连接方式,以及多个卷积层组+全连接方式。其代码详解为:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.backbone import resnet
from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.modeling.make_layers import make_fc
# 使用ResNet50的Conv5层来提取roi特征
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
class ResNet50Conv5ROIFeatureExtractor(nn.Module):
def __init__(self, config, in_channels):
super(ResNet50Conv5ROIFeatureExtractor, self).__init__()
# resolution为roi pooling之后特征图的大小,一般为7
resolution = config.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
# 获得特征图相对于原始图的缩放比例(scale),比如Res50的stage2特征图相对于原始图是1/4
scales = config.MODEL.ROI_BOX_HEAD.POOLER_SCALES
# sampling_ratio即采样率,指的是RoIAlign中每个输出bin在每个方向上的采样点个数(用于双线性插值);取值<=0时按RoI大小自适应确定采样点数
sampling_ratio = config.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
# 初始化池化类,内含ROIAlign函数
pooler =