Tensorflow object detection API 源码阅读笔记:架构

在之前的博文中介绍过用tf提供的预训练模型进行inference,非常简单。这里我们深入源码,了解检测API的代码架构,每个部分的深入阅读留待后续。

首先官方文档还是比较丰富的,可以先全看一遍,然后和核心的模型有关的文档是:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/defining_your_own_model.md
还有一个比较麻烦的地方是这里使用protobuf文件来管理参数配置,参见:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/configuring_jobs.md

'''构建自己模型的接口是虚基类DetectionModel,具体有5个抽象函数需要实现。
'''
object_detection/core/model.py
  def groundtruth_lists(self, field):
    """Access list of groundtruth tensors."""
  def groundtruth_has_field(self, field):
    """Determines whether the groundtruth includes the given field."""
  def provide_groundtruth(self,
                          groundtruth_boxes_list,
                          groundtruth_classes_list,
                          groundtruth_masks_list=None,
                          groundtruth_keypoints_list=None):
    """Provide groundtruth tensors."""

  @abstractmethod
  def preprocess(self, inputs):

  @abstractmethod
  def predict(self, preprocessed_inputs)

  @abstractmethod
  def postprocess(self, prediction_dict, **params)

  @abstractmethod
  def loss(self, prediction_dict)

  @abstractmethod
  def restore_map(self, from_detection_checkpoint=True)
object_detection/meta_architectures/faster_rcnn_meta_arch.py

class FasterRCNNFeatureExtractor(object):
  """Faster R-CNN Feature Extractor definition."""
  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0)

  @abstractmethod
  def preprocess(self, resized_inputs):
    """Feature-extractor specific preprocessing (minus image resizing)."""

  def extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features."""
  @abstractmethod
  def _extract_proposal_features(self, preprocessed_inputs, scope):

  def extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features."""
  @abstractmethod
  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features, to be overridden."""

  def restore_from_classification_checkpoint_fn(
      self,
      first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint."""

class FasterRCNNMetaArch(model.DetectionModel):
  """Faster R-CNN Meta-architecture definition."""
  """暂时主要看哪些地方调用了feature_extractor: A FasterRCNNFeatureExtractor object.换一个cnn还是比较简单的,只需要重写一个faster_rcnn_new_cnn_feature_extractor。最终构建的检测模型是这个类的对象。"""

  def preprocess(self, inputs):
  """For Faster R-CNN, we perform image resizing in the base class --- each
    class subclassing FasterRCNNMetaArch is responsible for any additional
    preprocessing (e.g., scaling pixel values to be in [-1, 1]).
    见下面代码块中实现的preprocess函数"""

object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py
"""这一块和slim结合紧密,我们仔细看看。
"""

class FasterRCNNResnetV1FeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
  """Faster R-CNN Resnet V1 feature extractor implementation."""
    def __init__(self,
               architecture,
               resnet_model,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):

    def preprocess(self, resized_inputs):
    """Faster R-CNN Resnet V1 preprocessing."""
        channel_means = [123.68, 116.779, 103.939]
        return resized_inputs - [[channel_means]]

    def _extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features.
    使用endpoints输出resnet block3的值。
    """

    def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features.
    拆分出resnet的block4。注意variable_scope和arg_scope的使用。
    """

class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
  """Faster R-CNN Resnet 152 feature extractor implementation."""

  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    """Constructor.
    Args:
      is_training: See base class.
      first_stage_features_stride: See base class.
      batch_norm_trainable: See base class.
      reuse_weights: See base class.
      weight_decay: See base class.
    Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16,
        or if `architecture` is not supported.
    """
    super(FasterRCNNResnet152FeatureExtractor, self).__init__(
        'resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
        first_stage_features_stride, batch_norm_trainable,
        reuse_weights, weight_decay)
    """往前看各个类的init,'resnet_v1_152', resnet_v1.resnet_v1_152只用在了上面的class FasterRCNNResnetV1FeatureExtractor"""

同样建议跑一跑test脚本。会遇到如下文件,按照test中出现的顺序逐个阅读这些文件,以及对应的test脚本。

"""Builder function to construct tf-slim arg_scope for convolution, fc ops.
看一下这个脚本的test,很容易理解超参数配置是怎么读取的了,类似OpenFOAM中的dict。object_detection.protos.hyperparams_pb2.Hyperparams。
"""
from object_detection.builders import hyperparams_builder

"""Contains routines for printing protocol messages in text format.
同样是上面这个test脚本,目前主要用在    
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
其中conv_hyperparams_text_proto是包含参数配置的字符串,conv_hyperparams_proto是hyperparams.proto object,hyperparams_builder.build的第一个参数。
"""
from google.protobuf import text_format

"""Function to build box predictor from configuration.
Box predictors are classes that take a high level
image feature map as input and produce two predictions,
(1) a tensor encoding box locations, and
(2) a tensor encoding classes for each box.
object_detection/core/box_predictor.py留待后续研读。注意conv_hyperparams_text_proto是放进box_predictor_text_proto然后一起传递给class ConvolutionalBoxPredictor(BoxPredictor)的。
"""
from object_detection.builders import box_predictor_builder

"""Generates grid anchors on the fly as used in Faster RCNN.
下次细看。
"""
from object_detection.anchor_generators import grid_anchor_generator

"""Builder function for post processing operations."""
from object_detection.builders import post_processing_builder

"""Classification and regression loss functions for object detection."""
from object_detection.core import losses

"""proto文件,下次再结合相应的core和builder来具体研究如何编写和读取这些文件"""
from object_detection.protos import box_predictor_pb2
from object_detection.protos import hyperparams_pb2
from object_detection.protos import post_processing_pb2
"""A function to build a DetectionModel from configuration.
很多内容在faster_rcnn_meta_arch_test_lib.py测试过了。
"""
object_detection/builders/model_builder.py
  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值