Deteron2 Faster-RCNN 代码阅读笔记
整体结构
在 Detectron2 中 Faster/Mask RCNN 系列是通过 GeneralizedRCNN 来实现的,代码位于 detectron2\modeling\meta_arch\rcnn.py
。类的关系入下图所示
GeneralizedRCNN 由三部分组成 : backbone、proposal_generator 和 roi_heads,分别通过 build_backbone,build_proposal_generator, build_roi_heads
构建。初始化的部分代码如下。目前 Detectron2 支持的 Backbone 有 Resnet50(+FPN), Rest101(+FPN), proposal_generator
由 detectron2\modeling\proposal_generator\rpn.py
中的 RPN
提供, roi_heads
由 detectron2\modeling\roi_heads\roi_heads.py
中的 StandardROIHeads
提供。
class GeneralizedRCNN(nn.Module):
"""
Generalized R-CNN. Any models that contains the following three components:
1. Per-image feature extraction (aka backbone)
2. Region proposal generation
3. Per-region feature extraction and prediction
"""
def __init__(self, cfg):
super().__init__()
self.device = torch.device(cfg.MODEL.DEVICE)
self.backbone = build_backbone(cfg)
self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
self.vis_period = cfg.VIS_PERIOD
self.input_format = cfg.INPUT.FORMAT
GeneralizedRCNN
的前向传播过程为过程为
输入参数为 bachted_inputs
,类型是字典,字典中包括 image, instances, proposals, height, width
。返回值也是一个字典,包括 pred_boxes, pred_classes, scores, pred_masks, pred_keypoints
.