Caffe2 - (二十八)Detectron 之 modeling - model_builder
Detectron 支持很多种模型类型,可以进行多种参数配置.
"""
Detectron 模型构建函数
一个给定的模型由以下部分组成:
- backbone 主干网络,如 VGG16, ResNet, ResNeXt
- FPN (on or off)
- RPN only (只用来生成 proposals)
- 固定 proposals,用于 Fast R-CNN, RFCN, Mask R-CNN(with or without keypoints) (非 end-to-end)
- End-to-end 模型,如 RPN + Fast R-CNN (i.e., Faster R-CNN), Mask R-CNN, ...
- 模型可以有不同的 head 选择
- ... 更多的配置选择 ...
模型通过组合许多基础模块来进行构建;即使某些模型看起来比较复杂,这种构建方式依然很灵活.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
import importlib
import logging
from caffe2.python import core
from caffe2.python import workspace
from core.config import cfg # 参数配置
from modeling.detector import DetectionModelHelper # Detectron 定义的模型创建 Helper
from roi_data.loader import RoIDataLoader #
import modeling.fast_rcnn_heads as fast_rcnn_heads
import modeling.keypoint_rcnn_heads as keypoint_rcnn_heads
import modeling.mask_rcnn_heads as mask_rcnn_heads
import modeling.name_compat
import modeling.optimizer as optim
import modeling.retinanet_heads as retinanet_heads
import modeling.rfcn_heads as rfcn_heads
import modeling.rpn_heads as rpn_heads
import roi_data.minibatch
import utils.c2 as c2_utils
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------- #
# 通用可组合的模型构建器
#
# 例如,使用 ResNet-50-C4 backbone 的 Fast R-CNN 模型,其构建配置如下:
# MODEL:
#   TYPE: generalized_rcnn
#   CONV_BODY: ResNet.add_ResNet50_conv4_body
# FAST_RCNN:
#   ROI_BOX_HEAD: ResNet.add_ResNet_roi_conv5_head
# ---------------------------------------------------------------------------- #
def generalized_rcnn(model):
    """Build a generalized R-CNN style model described by the global cfg.

    This model type covers:
    - Fast R-CNN
    - RPN only (not integrated with Fast R-CNN)
    - Faster R-CNN (stagewise training from the NIPS paper)
    - Faster R-CNN (end-to-end joint training)
    - Mask R-CNN (stagewise training from the NIPS paper)
    - Mask R-CNN (end-to-end joint training)

    The backbone and the box / mask / keypoint heads are looked up by name
    with get_func(); an empty name in the config resolves to None.

    Args:
        model: the model helper the network is added to.

    Returns:
        Whatever build_generic_detection_model returns for ``model``.
    """
    # Resolve in a fixed order (backbone, box, mask, keypoint) because
    # get_func may log name-remapping warnings as a side effect.
    conv_body_fn = get_func(cfg.MODEL.CONV_BODY)
    box_head_fn = get_func(cfg.FAST_RCNN.ROI_BOX_HEAD)
    mask_head_fn = get_func(cfg.MRCNN.ROI_MASK_HEAD)
    keypoint_head_fn = get_func(cfg.KRCNN.ROI_KEYPOINTS_HEAD)
    freeze_body = cfg.TRAIN.FREEZE_CONV_BODY

    return build_generic_detection_model(
        model,
        conv_body_fn,
        add_roi_box_head_func=box_head_fn,
        add_roi_mask_head_func=mask_head_fn,
        add_roi_keypoint_head_func=keypoint_head_fn,
        freeze_conv_body=freeze_body
    )
def rfcn(model):
    """Build an R-FCN model whose backbone is named by cfg.MODEL.CONV_BODY.

    Args:
        model: the model helper the network is added to.

    Returns:
        Whatever build_generic_rfcn_model returns for ``model``.
    """
    # TODO(rbg): fold into build_generic_detection_model
    add_conv_body = get_func(cfg.MODEL.CONV_BODY)
    return build_generic_rfcn_model(model, add_conv_body)
def retinanet(model):
    """Build a RetinaNet model whose backbone is named by cfg.MODEL.CONV_BODY.

    Args:
        model: the model helper the network is added to.

    Returns:
        Whatever build_generic_retinanet_model returns for ``model``.
    """
    # TODO(rbg): fold into build_generic_detection_model
    add_conv_body = get_func(cfg.MODEL.CONV_BODY)
    return build_generic_retinanet_model(model, add_conv_body)
# ---------------------------------------------------------------------------- #
# 构建不同的 re-usable 网络的 Helper functions
# ---------------------------------------------------------------------------- #
def create(model_type_func, train=False):
    """Create a DetectionModelHelper and run the named model builder on it.

    Args:
        model_type_func: name of a model-building function, resolved with
            get_func() (e.g. 'generalized_rcnn'). The same string is used as
            the helper's ``name``.
        train: forwarded to DetectionModelHelper as both ``train`` and
            ``init_params``.

    Returns:
        Whatever the resolved model-building function returns.
    """
    helper_kwargs = dict(
        name=model_type_func,
        train=train,
        num_classes=cfg.MODEL.NUM_CLASSES,
        init_params=train
    )
    helper = DetectionModelHelper(**helper_kwargs)
    # Resolve the builder only after the helper exists, then apply it.
    build_fn = get_func(model_type_func)
    return build_fn(helper)
def get_func(func_name):
    """Return the function object referenced by ``func_name``.

    ``func_name`` is either the name of a function defined in this module
    (e.g. 'generalized_rcnn') or a dotted path relative to the base
    'modeling' package (e.g. 'ResNet.add_ResNet50_conv4_body').

    Args:
        func_name: function name or dotted path. An empty string (or None)
            means "no function".

    Returns:
        The resolved function object, or None when ``func_name`` is empty.

    Raises:
        Re-raises the lookup failure (e.g. KeyError for an unknown local
        name, ImportError for a missing module, AttributeError for a missing
        attribute) after logging the name that could not be resolved.
    """
    if not func_name:
        return None
    # Translate legacy function names to their current names via
    # modeling.name_compat so that old names keep working.
    new_func_name = modeling.name_compat.get_new_name(func_name)
    if new_func_name != func_name:
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning(
            'Remapping old function name: %s -> %s', func_name, new_func_name
        )
        func_name = new_func_name
    try:
        parts = func_name.split('.')
        # A bare name refers to a function defined in this module.
        if len(parts) == 1:
            return globals()[parts[0]]
        # Otherwise treat it as a path relative to the 'modeling' package.
        module_name = 'modeling.' + '.'.join(parts[:-1])
        module = importlib.import_module(module_name)
        return getattr(module, parts[-1])
    except Exception:
        logger.error('Failed to find function: %s', func_name)
        raise
def build_generic_detection_model(model,
add_conv_body_func,
add_roi_box_head_func=None,
add_roi_mask_head_func=None,
add_roi_keypoint_head_func=None,
freeze_conv_body=False ):
def _single_gpu_build_func