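This post uses Faster RCNN as the thread for the analysis. SSD and the like are presumably similar, though I have not read them yet. As a beginner, I usually start reading code from the first file. In my mind, Faster RCNN takes feature maps from a base framework such as a CNN and runs detection on them, so I went looking for the loss and output functions in train.py and trainer.py. At first glance I could not find an explicit loss function or output there.
So... let's start with train.py after all.
The imports in train.py are as follows: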
train.py
#train.py
import functools
import json
import os
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import trainer
from object_detection.builders import input_reader_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
The model_builder import caught my attention, so let's jump to model_builder.py, which imports quite a lot:
model_builder.py
#model_builder.py
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import box_predictor
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.protos import model_pb2
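The faster_rcnn_inception_resnet_v2_feature_extractor here is mentioned in the documentation; see "object_detectionAPI source code reading notes".
So... faster_rcnn_inception_resnet_v2_feature_extractor.py is what I was after. It is referenced in the config file and performs the feature extraction. It plugs into a meta-architecture that is a subclass of DetectionModel (object_detection/core/model.py); where that subclass relationship comes from is explained below.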
#model_builder.py
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
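The aliases imported above hint at how model_builder.py can pick a feature extractor: the config names one, and the builder looks up the matching class. The sketch below imitates that lookup in plain Python. The two classes, the `EXTRACTOR_CLASS_MAP` dict and `build_feature_extractor` are hypothetical stand-ins for illustration, not the library's actual code.

```python
# Hypothetical stand-ins for the feature extractor classes; the real ones
# live under object_detection/models and need TensorFlow.
class InceptionResnetV2Extractor:
    backbone = "inception_resnet_v2"


class ResnetV1Extractor:
    backbone = "resnet_v1"


# Assumed mapping from a type string in the config to an extractor class.
EXTRACTOR_CLASS_MAP = {
    "faster_rcnn_inception_resnet_v2": InceptionResnetV2Extractor,
    "faster_rcnn_resnet101": ResnetV1Extractor,
}


def build_feature_extractor(extractor_type):
    """Return an extractor instance for the given config type string."""
    if extractor_type not in EXTRACTOR_CLASS_MAP:
        raise ValueError("Unknown feature extractor type: %s" % extractor_type)
    return EXTRACTOR_CLASS_MAP[extractor_type]()


extractor = build_feature_extractor("faster_rcnn_inception_resnet_v2")
print(extractor.backbone)
```

A lookup table like this keeps the builder free of long if/else chains: supporting a new backbone only means adding one entry.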
So let's look at what faster_rcnn_inception_resnet_v2_feature_extractor.py imports.
faster_rcnn_inception_resnet_v2_feature_extractor.py
#faster_rcnn_inception_resnet_v2_feature_extractor.py
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch
from nets import inception_resnet_v2
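This imports the inception_resnet_v2 model, and it feels like I have found the ancestor: the base CNN here is inception_resnet_v2, a basic CNN architecture. The nets package holds many other base networks, including vgg and alexnet.
So faster_rcnn_inception_resnet_v2_feature_extractor.py appears to be the file that extracts features from the base network.
It also imports faster_rcnn_meta_arch, though, so let's look at faster_rcnn_meta_arch.py.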
faster_rcnn_meta_arch.py
#faster_rcnn_meta_arch.py
from abc import abstractmethod
from functools import partial
import tensorflow as tf
from object_detection.anchor_generators import grid_anchor_generator
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import box_predictor
from object_detection.core import losses
from object_detection.core import model
from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
from object_detection.utils import shape_utils
Here we find model.py, which holds the base class of all detection models.
model.py
#model.py
from abc import ABCMeta
from abc import abstractmethod
from object_detection.core import standard_fields as fields
It really is the base class: it imports almost nothing. It is the ancestor of the detection models, and about as bare as it gets.
class DetectionModel(object):
  """Abstract base class for detection models."""
  __metaclass__ = ABCMeta

  def __init__(self, num_classes):
    """Constructor.
The base class DetectionModel(object) is implemented in this file.
During training, the base class defines the following flow: inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
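To make the training flow concrete, here is a minimal sketch of an abstract base class with the same method names, plus a toy subclass. Everything in it (`ToyDetectionModel`, `MeanModel`, the scalar "image") is a hypothetical illustration rather than the library's code. One difference from the excerpt above: `__metaclass__ = ABCMeta` is the Python 2 spelling, so the sketch uses the Python 3 form `class ...(ABC)`.

```python
from abc import ABC, abstractmethod


class ToyDetectionModel(ABC):
    """Toy version of the abstract interface (Python 3 syntax)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.groundtruth = None

    def provide_groundtruth(self, groundtruth):
        # Store the labels so that loss() can compare against them later.
        self.groundtruth = groundtruth

    @abstractmethod
    def preprocess(self, inputs):
        """Scale or normalize the raw inputs."""

    @abstractmethod
    def predict(self, preprocessed_inputs):
        """Return a dict of raw predictions."""

    @abstractmethod
    def loss(self, prediction_dict):
        """Return a dict mapping loss names to loss values."""


class MeanModel(ToyDetectionModel):
    """Hypothetical subclass: 'predicts' the mean of the scaled pixels."""

    def preprocess(self, inputs):
        return [x / 255.0 for x in inputs]

    def predict(self, preprocessed_inputs):
        mean = sum(preprocessed_inputs) / len(preprocessed_inputs)
        return {"score": mean}

    def loss(self, prediction_dict):
        diff = prediction_dict["score"] - self.groundtruth
        return {"toy_loss": diff * diff}


model = MeanModel(num_classes=1)
model.provide_groundtruth(1.0)
# inputs -> preprocess -> predict -> loss
losses = model.loss(model.predict(model.preprocess([255, 255])))
print(losses)  # {'toy_loss': 0.0}
```

Because the three methods are abstract, `ToyDetectionModel` itself cannot be instantiated; only a subclass that implements all of them can.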
1. File chain for train.py:
   1. model.py -> faster_rcnn_meta_arch.py -> faster_rcnn_inception_resnet_v2_feature_extractor.py
   2. inception_resnet_v2 -> faster_rcnn_inception_resnet_v2_feature_extractor.py -> model_builder.py
   3. model_builder.py -> train.py
2. File chain for eval.py:
   1. model.py -> faster_rcnn_meta_arch.py -> faster_rcnn_inception_resnet_v2_feature_extractor.py
   2. inception_resnet_v2 -> faster_rcnn_inception_resnet_v2_feature_extractor.py -> model_builder.py
   3. model_builder.py -> eval.py
Here is a diagram by haixwang.
trainer.py
- _create_losses()
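This file has a function dedicated to producing the loss. Read the comments carefully, and then carefully again. For an introduction to detection_model, see "object_detectionAPI source code reading notes (5)" and "object_detectionAPI source code reading notes (6)".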
def _create_losses(input_queue, create_model_fn, train_config):
  """Creates loss function for a DetectionModel.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
  """
  # Create a detection model.
  detection_model = create_model_fn()
  # Read the input data with get_inputs().
  (images, _, groundtruth_boxes_list, groundtruth_classes_list,
   groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs(
       input_queue,
       detection_model.num_classes,
       train_config.merge_multiple_label_boxes)
  # Preprocess (normalize) each image, then stack them into one batch.
  images = [detection_model.preprocess(image) for image in images]
  images = tf.concat(images, 0)
  if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None
  if any(keypoints is None for keypoints in groundtruth_keypoints_list):
    groundtruth_keypoints_list = None
  # Hand the ground-truth labels to the model.
  detection_model.provide_groundtruth(groundtruth_boxes_list,
                                      groundtruth_classes_list,
                                      groundtruth_masks_list,
                                      groundtruth_keypoints_list)
  # Run prediction.
  prediction_dict = detection_model.predict(images)
  # Compute the losses and register each one.
  losses_dict = detection_model.loss(prediction_dict)
  for loss_tensor in losses_dict.values():
    tf.losses.add_loss(loss_tensor)
The loss produced by _create_losses() is then handed to train() for optimization.
- train()
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function that creates the input tensors.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: the training configuration.
    master: name of the master used for distributed training.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: directory where the training files are saved.
  """
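At this point the training flow has been traced more or less end to end; trainer.train() is the final piece that wires everything together. I have also found the loss (see the create_model_fn parameter: a function that creates a DetectionModel and generates losses).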
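One detail worth illustrating: _create_losses() calls create_model_fn() with no arguments, so whoever calls train() has to bind the model configuration beforehand. train.py imports functools, and functools.partial is a natural way to do that binding. In the sketch below, `build` and the contents of `model_config` are hypothetical placeholders, not the real builder or a real config.

```python
import functools


def build(model_config, is_training):
    """Hypothetical builder: returns a description instead of a real model."""
    return {"type": model_config["type"], "is_training": is_training}


model_config = {"type": "faster_rcnn"}  # made-up config for illustration

# Bind the arguments now; the trainer can call create_model_fn() later
# without knowing anything about the config.
create_model_fn = functools.partial(
    build, model_config=model_config, is_training=True)

detection_model = create_model_fn()
print(detection_model)
```

Passing a factory instead of a ready-made model lets the trainer decide when and where the model gets constructed, for example once per clone.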
DetectionModel as used in eval.py
Overall prediction flow: inputs (images tensor) -> preprocess -> predict -> postprocess -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
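The prediction flow differs from training only in its last step: postprocess replaces loss. The toy functions below walk through it with plain lists. The candidate boxes, scores and the simple score threshold are invented for illustration; a real postprocess step does considerably more (non-maximum suppression, for instance), so treat this as a sketch of the data flow only.

```python
def preprocess(image):
    """Scale pixel values to [0, 1]."""
    return [pixel / 255.0 for pixel in image]


def predict(preprocessed_image):
    """Pretend to detect: returns fixed raw candidates (hypothetical values)."""
    return {
        "raw_boxes": [[0.1, 0.1, 0.5, 0.5],
                      [0.2, 0.2, 0.9, 0.9],
                      [0.0, 0.0, 0.3, 0.3]],
        "raw_scores": [0.9, 0.2, 0.6],
        "raw_classes": [1, 2, 1],
    }


def postprocess(prediction_dict, score_threshold=0.5):
    """Keep candidates whose score reaches the threshold (simplified)."""
    keep = [i for i, score in enumerate(prediction_dict["raw_scores"])
            if score >= score_threshold]
    return {
        "detection_boxes": [prediction_dict["raw_boxes"][i] for i in keep],
        "detection_scores": [prediction_dict["raw_scores"][i] for i in keep],
        "detection_classes": [prediction_dict["raw_classes"][i] for i in keep],
        "num_detections": len(keep),
    }


# inputs -> preprocess -> predict -> postprocess -> outputs
detections = postprocess(predict(preprocess([0, 128, 255])))
print(detections["num_detections"], detections["detection_scores"])
```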
eval.py imports the following packages:
# eval.py
import functools
import os
import tensorflow as tf
import evaluator
from object_detection.builders import input_reader_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import label_map_util
As with train, it is the evaluator that actually uses the DetectionModel.
- _extract_prediction_tensors()
def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                ignore_groundtruth=False):
  """Restores the model in a tensorflow session.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    tensor_dict: A tensor dictionary with evaluations.
  """
  # Build the input data queue.
  input_dict = create_input_dict_fn()
  prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
  input_dict = prefetch_queue.dequeue()
  original_image = tf.expand_dims(input_dict[fields.InputDataFields.image], 0)
  # Preprocess the image for the detection model.
  preprocessed_image = model.preprocess(tf.to_float(original_image))
  # Run prediction.
  prediction_dict = model.predict(preprocessed_image)
  # Run post-processing.
  detections = model.postprocess(prediction_dict)
  # Collect the ground-truth labels.
  groundtruth = None
  if not ignore_groundtruth:
    groundtruth = {
        fields.InputDataFields.groundtruth_boxes:
            input_dict[fields.InputDataFields.groundtruth_boxes],
        fields.InputDataFields.groundtruth_classes:
            input_dict[fields.InputDataFields.groundtruth_classes],
        fields.InputDataFields.groundtruth_area:
            input_dict[fields.InputDataFields.groundtruth_area],
        fields.InputDataFields.groundtruth_is_crowd:
            input_dict[fields.InputDataFields.groundtruth_is_crowd],
        fields.InputDataFields.groundtruth_difficult:
            input_dict[fields.InputDataFields.groundtruth_difficult]
    }
    if fields.InputDataFields.groundtruth_group_of in input_dict:
      groundtruth[fields.InputDataFields.groundtruth_group_of] = (
          input_dict[fields.InputDataFields.groundtruth_group_of])
    if fields.DetectionResultFields.detection_masks in detections:
      groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
          input_dict[fields.InputDataFields.groundtruth_instance_masks])
  return eval_util.result_dict_for_single_example(
      original_image,
      input_dict[fields.InputDataFields.source_id],
      detections,
      groundtruth,
      class_agnostic=(
          fields.DetectionResultFields.detection_classes not in detections),
      scale_to_absolute=True)
For the details of the detection model, please keep reading!