Caffe2 - (二十八) Detectron 之 modeling - model_builder

Caffe2 - (二十八)Detectron 之 modeling - model_builder

Detectron 支持很多种模型类型,可以进行多种参数配置.

"""
Detectron 模型构建函数
一个给定模型是有:
 - backbone 主干网络,如 VGG16, ResNet, ResNeXt
 - FPN (on or off)
 - PRN only (只用来生成 proposals)
 - 固定 proposals,用于 Fast R-CNN, RFCN, Mask R-CNN(with or without keypoints) (非 end-to-end)
 - End-to-end 模型,如 RPN + Fast R-CNN (i.e., Faster R-CNN), Mask R-CNN, ...
 - 模型可以有不同的 head 选择
 - ... 跟多的配置选择 ...

模型可以通过组合许多基础模块来进行构建,即使某些模型看起来比较复杂,但比较灵活.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import copy
import importlib
import logging

from caffe2.python import core
from caffe2.python import workspace

from core.config import cfg # 参数配置
from modeling.detector import DetectionModelHelper # Detectron 定义的模型创建 Helper
from roi_data.loader import RoIDataLoader  # 
import modeling.fast_rcnn_heads as fast_rcnn_heads 
import modeling.keypoint_rcnn_heads as keypoint_rcnn_heads
import modeling.mask_rcnn_heads as mask_rcnn_heads
import modeling.name_compat
import modeling.optimizer as optim 
import modeling.retinanet_heads as retinanet_heads
import modeling.rfcn_heads as rfcn_heads
import modeling.rpn_heads as rpn_heads
import roi_data.minibatch
import utils.c2 as c2_utils

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------- #
# 通用可组合的模型构建器
# 
# 如,Fast R-CNN model with the ResNet-50-C4 backbone 的模型构建配置:
# MODEL:
#   TYPE: generalized_rcnn
#   CONV_BODY: ResNet.add_ResNet50_conv4_body
#   ROI_HEAD: ResNet.add_ResNet_roi_conv5_head
# ---------------------------------------------------------------------------- #

def generalized_rcnn(model):
    """
    该模型类型可处理:
      - Fast R-CNN
      - RPN only (not integrated with Fast R-CNN)
      - Faster R-CNN (stagewise training from NIPS paper)
      - Faster R-CNN (end-to-end joint training)
      - Mask R-CNN (stagewise training from NIPS paper)
      - Mask R-CNN (end-to-end joint training)
    """
    return build_generic_detection_model(
        model,
        get_func(cfg.MODEL.CONV_BODY),
        add_roi_box_head_func=get_func(cfg.FAST_RCNN.ROI_BOX_HEAD),
        add_roi_mask_head_func=get_func(cfg.MRCNN.ROI_MASK_HEAD),
        add_roi_keypoint_head_func=get_func(cfg.KRCNN.ROI_KEYPOINTS_HEAD),
        freeze_conv_body=cfg.TRAIN.FREEZE_CONV_BODY )


def rfcn(model):
    # TODO(rbg): fold into build_generic_detection_model
    return build_generic_rfcn_model(model, get_func(cfg.MODEL.CONV_BODY))


def retinanet(model):
    # TODO(rbg): fold into build_generic_detection_model
    return build_generic_retinanet_model(model, get_func(cfg.MODEL.CONV_BODY))


# ---------------------------------------------------------------------------- #
# 构建不同的 re-usable 网络的 Helper functions
# ---------------------------------------------------------------------------- #

def create(model_type_func, train=False):
    """
    通用模型构建函数,指定模型构建函数.
    """
    model = DetectionModelHelper(name=model_type_func,
                                 train=train,
                                 num_classes=cfg.MODEL.NUM_CLASSES,
                                 init_params=train )
    return get_func(model_type_func)(model)


def get_func(func_name):
    """
    根据 name 返回函数对象function object.
    func_name 必须是该模块里的某个函数或者是想对于 base 'modeling' 模块的函数路径.
    """
    if func_name == '':
        return None
    new_func_name = modeling.name_compat.get_new_name(func_name) # 保持网络名的我兼容性
    if new_func_name != func_name:
        logger.warn('Remapping old function name: {} -> {}'.format(func_name, new_func_name) )
        func_name = new_func_name
    try:
        parts = func_name.split('.')
        # 指向该模块中的一个函数
        if len(parts) == 1:
            return globals()[parts[0]]
        # 否则, 假设已经引用了模型的一个模块
        module_name = 'modeling.' + '.'.join(parts[:-1])
        module = importlib.import_module(module_name)
        return getattr(module, parts[-1])
    except Exception:
        logger.error('Failed to find function: {}'.format(func_name))
        raise


def build_generic_detection_model(model, 
                                  add_conv_body_func,
                                  add_roi_box_head_func=None,
                                  add_roi_mask_head_func=None,
                                  add_roi_keypoint_head_func=None,
                                  freeze_conv_body=False ):
    def _single_gpu_build_func
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值