在之前的博文中介绍过如何用TensorFlow提供的预训练模型进行inference,非常简单。这里我们深入源码,了解Object Detection API的代码架构,各部分的深入阅读留待后续。
首先,官方文档还是比较丰富的,建议先通读一遍;其中与核心模型相关的文档是:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/defining_your_own_model.md
还有一个比较麻烦的地方是这里使用protobuf文件来管理参数配置,参见:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/configuring_jobs.md
'''构建自己的模型需要继承抽象基类DetectionModel,并实现其中的5个抽象方法:preprocess、predict、postprocess、loss和restore_map。
'''
object_detection/core/model.py
# Excerpt of class DetectionModel (object_detection/core/model.py). The class
# header and the method bodies are omitted in this listing.
# NOTE(review): the abstract-method signatures below are abbreviated (some lack
# the trailing colon and a body), so this listing is not runnable as-is.
def groundtruth_lists(self, field):
"""Access list of groundtruth tensors."""
def groundtruth_has_field(self, field):
"""Determines whether the groundtruth includes the given field."""
# Masks and keypoints are optional (they default to None); boxes and classes
# are required. Each argument is a list, presumably one entry per image in
# the batch -- confirm against the upstream docstring.
def provide_groundtruth(self,
groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_masks_list=None,
groundtruth_keypoints_list=None):
"""Provide groundtruth tensors."""
# The five methods below are decorated with @abstractmethod. A concrete
# detection model (e.g. FasterRCNNMetaArch, which subclasses
# model.DetectionModel) must implement all of them.
@abstractmethod
def preprocess(self, inputs):
@abstractmethod
def predict(self, preprocessed_inputs)
# postprocess() and loss() both take a prediction_dict, presumably the
# output of predict() -- confirm.
@abstractmethod
def postprocess(self, prediction_dict, **params)
@abstractmethod
def loss(self, prediction_dict)
# Presumably returns a mapping of variables to restore from a checkpoint;
# from_detection_checkpoint defaults to True -- confirm semantics upstream.
@abstractmethod
def restore_map(self, from_detection_checkpoint=True)
object_detection/meta_architectures/faster_rcnn_meta_arch.py
class FasterRCNNFeatureExtractor(object):
"""Faster R-CNN Feature Extractor definition.

Base class for the CNN backbone held by FasterRCNNMetaArch as its
`feature_extractor`. The @abstractmethod-decorated methods that subclasses
must provide are preprocess, _extract_proposal_features and
_extract_box_classifier_features.
"""
# Excerpt from object_detection/meta_architectures/faster_rcnn_meta_arch.py.
# Method bodies are omitted, and the __init__ signature below is abbreviated
# (no trailing colon), so this listing is not runnable as-is.
def __init__(self,
is_training,
first_stage_features_stride,
batch_norm_trainable=False,
reuse_weights=None,
weight_decay=0.0)
@abstractmethod
def preprocess(self, resized_inputs):
"""Feature-extractor specific preprocessing (minus image resizing)."""
# Each public extract_* method is paired with an abstract _extract_* method.
# Subclasses such as FasterRCNNResnetV1FeatureExtractor implement only the
# underscore-prefixed versions.
def extract_proposal_features(self, preprocessed_inputs, scope):
"""Extracts first stage RPN features."""
@abstractmethod
def _extract_proposal_features(self, preprocessed_inputs, scope):
def extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features."""
@abstractmethod
def _extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features, to be overridden."""
# Takes one scope name for the first-stage extractor and one for the
# second-stage extractor.
def restore_from_classification_checkpoint_fn(
self,
first_stage_feature_extractor_scope,
second_stage_feature_extractor_scope):
"""Returns a map of variables to load from a foreign checkpoint."""
class FasterRCNNMetaArch(model.DetectionModel):
"""Faster R-CNN Meta-architecture definition."""
# Note (English translation of the string below): for now, focus on where
# `feature_extractor` (a FasterRCNNFeatureExtractor object) is used.
# Swapping in a different CNN is fairly simple: you only need to write a new
# faster_rcnn_new_cnn_feature_extractor. The detection model that is finally
# built is an instance of this class.
"""暂时主要看哪些地方调用了feature_extractor: A FasterRCNNFeatureExtractor object.换一个cnn还是比较简单的,只需要重写一个faster_rcnn_new_cnn_feature_extractor。最终构建的检测模型是这个类的对象。"""
def preprocess(self, inputs):
"""For Faster R-CNN, we perform image resizing in the base class --- each
class subclassing FasterRCNNMetaArch is responsible for any additional
preprocessing (e.g., scaling pixel values to be in [-1, 1]).
See the preprocess function implemented in the code block below
(FasterRCNNResnetV1FeatureExtractor.preprocess)."""
object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py
"""这一块和slim结合紧密,我们仔细看看。
"""
class FasterRCNNResnetV1FeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
"""Faster R-CNN Resnet V1 feature extractor implementation."""
# Excerpt from object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py.
# The bodies of __init__ and of the two _extract_* methods are omitted here.
# `architecture` and `resnet_model` are the two extra parameters compared with
# the base FasterRCNNFeatureExtractor.__init__; subclasses fill them in.
def __init__(self,
architecture,
resnet_model,
is_training,
first_stage_features_stride,
batch_norm_trainable=False,
reuse_weights=None,
weight_decay=0.0):
def preprocess(self, resized_inputs):
"""Faster R-CNN Resnet V1 preprocessing.

Subtracts the constant `channel_means` from `resized_inputs` and returns
the result.
"""
# One mean value per channel, presumably in RGB order -- confirm against
# the preprocessing used by the pretrained checkpoint.
channel_means = [123.68, 116.779, 103.939]
# [[channel_means]] is a 1x1x3 nested list, so the subtraction relies on
# broadcasting over the last axis; assumes resized_inputs is channels-last
# with 3 channels -- TODO confirm.
return resized_inputs - [[channel_means]]
def _extract_proposal_features(self, preprocessed_inputs, scope):
"""Extracts first stage RPN features.
Uses the network's end_points to return the output of ResNet block3.
"""
def _extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features.
Splits off ResNet's block4. Note how variable_scope and arg_scope are used.
"""
class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
"""Faster R-CNN Resnet 152 feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride,
batch_norm_trainable=False,
reuse_weights=None,
weight_decay=0.0):
"""Constructor.
Fixes the ResNet variant to 'resnet_v1_152' and forwards all remaining
arguments unchanged to FasterRCNNResnetV1FeatureExtractor.__init__.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
batch_norm_trainable: See base class.
reuse_weights: See base class.
weight_decay: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16,
or if `architecture` is not supported. (This constructor has no
`architecture` parameter; any such check presumably happens in the
base-class constructor, whose body is not shown here.)
"""
# Positional call: 'resnet_v1_152' and resnet_v1.resnet_v1_152 fill the
# base class's `architecture` and `resnet_model` parameters; the remaining
# arguments keep the same order as in the base signature.
super(FasterRCNNResnet152FeatureExtractor, self).__init__(
'resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
first_stage_features_stride, batch_norm_trainable,
reuse_weights, weight_decay)
# Note (English translation of the string below): looking back at each
# class's __init__, 'resnet_v1_152' and resnet_v1.resnet_v1_152 are used only
# by class FasterRCNNResnetV1FeatureExtractor above (as its `architecture`
# and `resnet_model` arguments).
"""往前看各个类的init,'resnet_v1_152', resnet_v1.resnet_v1_152只用在了上面的class FasterRCNNResnetV1FeatureExtractor"""
同样建议运行一下对应的test脚本。其中会用到如下模块,建议按照它们在test中出现的顺序,逐个阅读这些文件以及对应的test脚本。
"""Builder function to construct tf-slim arg_scope for convolution, fc ops.
看一下这个脚本对应的test,就很容易理解超参数配置是怎么读取的了,类似于OpenFOAM中的dict。对应的proto message是object_detection.protos.hyperparams_pb2.Hyperparams。
"""
from object_detection.builders import hyperparams_builder
"""Contains routines for printing protocol messages in text format.
在上面这个test脚本中,它目前主要用在以下两行:
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
其中conv_hyperparams_text_proto是包含参数配置的文本字符串;conv_hyperparams_proto是一个Hyperparams proto对象(由hyperparams.proto定义),它作为hyperparams_builder.build的第一个参数。
"""
from google.protobuf import text_format
"""Function to build box predictor from configuration.
Box predictors are classes that take a high level
image feature map as input and produce two predictions,
(1) a tensor encoding box locations, and
(2) a tensor encoding classes for each box.
object_detection/core/box_predictor.py留待后续研读。注意:conv_hyperparams_text_proto是先嵌入到box_predictor_text_proto中,再一起传递给class ConvolutionalBoxPredictor(BoxPredictor)的。
"""
from object_detection.builders import box_predictor_builder
"""Generates grid anchors on the fly as used in Faster RCNN.
下次细看。
"""
from object_detection.anchor_generators import grid_anchor_generator
"""Builder function for post processing operations."""
from object_detection.builders import post_processing_builder
"""Classification and regression loss functions for object detection."""
from object_detection.core import losses
"""proto文件,下次再结合相应的core和builder来具体研究如何编写和读取这些文件"""
from object_detection.protos import box_predictor_pb2
from object_detection.protos import hyperparams_pb2
from object_detection.protos import post_processing_pb2
"""A function to build a DetectionModel from configuration.
很多内容在faster_rcnn_meta_arch_test_lib.py测试过了。
"""
object_detection/builders/model_builder.py