object_detection API Source Code Reading Notes (3: train.py)

 

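This post follows Faster RCNN as its thread. SSD and the other models are probably similar, but I haven't read them yet. As a beginner, I usually start reading code from the first file. In my mind, Faster RCNN extracts feature maps from a base CNN and runs detection on them, so I expected to find the loss and output functions in train.py and trainer.py. There don't seem to be any: I couldn't find an explicit loss function or output.
So..... let's start with train.py anyway.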

train.py imports the following:

train.py

#train.py
import functools
import json
import os
# Set before importing TensorFlow so the log filter is in place from the start
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import trainer
from object_detection.builders import input_reader_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util

model_builder caught my attention here, so let's jump to model_builder.py. It imports a lot:

model_builder.py

#model_builder.py
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import box_predictor
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.protos import model_pb2
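
model_builder.py imports every feature extractor because it picks one at build time based on the config. Below is a minimal sketch of that dispatch pattern, assuming a plain dict keyed by the feature-extractor type string. The stub classes and `build_feature_extractor` are hypothetical stand-ins, not the library's API; the real builder reads a model_pb2 protobuf rather than a string.

```python
# Hypothetical stand-ins for the imported feature-extractor classes.
class FasterRCNNInceptionResnetV2Stub(object):
    def __init__(self, is_training):
        self.is_training = is_training


class FasterRCNNResnet101Stub(object):
    def __init__(self, is_training):
        self.is_training = is_training


# Map from the config's feature-extractor "type" string to a class.
FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_inception_resnet_v2': FasterRCNNInceptionResnetV2Stub,
    'faster_rcnn_resnet101': FasterRCNNResnet101Stub,
}


def build_feature_extractor(extractor_type, is_training):
    """Look up the class registered for `extractor_type` and instantiate it."""
    if extractor_type not in FEATURE_EXTRACTOR_CLASS_MAP:
        raise ValueError('Unknown feature extractor: {}'.format(extractor_type))
    return FEATURE_EXTRACTOR_CLASS_MAP[extractor_type](is_training=is_training)


extractor = build_feature_extractor('faster_rcnn_inception_resnet_v2', True)
print(type(extractor).__name__)  # FasterRCNNInceptionResnetV2Stub
```

In this sketch, supporting a new backbone only means adding one dict entry; the code that consumes the extractor does not change.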

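faster_rcnn_inception_resnet_v2_feature_extractor here is mentioned in the documentation; see the earlier post, object_detection API source code reading notes.
So... faster_rcnn_inception_resnet_v2_feature_extractor.py is what I'm after. The config documentation says it does feature extraction. It is tied to DetectionModel (object_detection/core/model.py), though not as a direct subclass: the extractor subclasses faster_rcnn_meta_arch.FasterRCNNFeatureExtractor, and it is FasterRCNNMetaArch that subclasses DetectionModel. How these pieces connect is explained below.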

So let's look at what faster_rcnn_inception_resnet_v2_feature_extractor.py imports.

faster_rcnn_inception_resnet_v2_feature_extractor.py

# faster_rcnn_inception_resnet_v2_feature_extractor.py

import tensorflow as tf

from object_detection.meta_architectures import faster_rcnn_meta_arch
from nets import inception_resnet_v2

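This imports the inception_resnet_v2 model, and I feel I've found the ancestor: the base CNN here is inception_resnet_v2, a standard CNN architecture. nets (the TF-Slim model library) contains many other base architectures as well, including vgg and alexnet.

So faster_rcnn_inception_resnet_v2_feature_extractor.py appears to be the file that wraps the base network to extract features from it.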
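
The feature extractor does not run detection by itself: the meta-architecture holds one and calls it (composition, not inheritance). A minimal sketch of that relationship with hypothetical toy classes; the method names `extract_proposal_features` and `extract_box_classifier_features` mirror those declared by FasterRCNNFeatureExtractor, but the list "tensors" and arithmetic are made up purely for illustration.

```python
from abc import ABC, abstractmethod


class FeatureExtractorBase(ABC):
    """Stand-in for faster_rcnn_meta_arch.FasterRCNNFeatureExtractor."""

    @abstractmethod
    def extract_proposal_features(self, images):
        """First stage: backbone features consumed by the RPN."""

    @abstractmethod
    def extract_box_classifier_features(self, proposal_features):
        """Second stage: features for the box classifier."""


class ToyBackboneExtractor(FeatureExtractorBase):
    """Plays the role of e.g. the Inception-ResNet-v2 extractor."""

    def extract_proposal_features(self, images):
        # Pretend the backbone halves every value.
        return [x / 2.0 for x in images]

    def extract_box_classifier_features(self, proposal_features):
        return [x + 1.0 for x in proposal_features]


class ToyTwoStageModel(object):
    """Stand-in for FasterRCNNMetaArch: it HAS a feature extractor."""

    def __init__(self, feature_extractor):
        self._feature_extractor = feature_extractor

    def predict(self, images):
        rpn_features = self._feature_extractor.extract_proposal_features(images)
        box_features = self._feature_extractor.extract_box_classifier_features(
            rpn_features)
        return {'rpn_features': rpn_features, 'box_features': box_features}


model = ToyTwoStageModel(ToyBackboneExtractor())
print(model.predict([2.0, 4.0]))
# {'rpn_features': [1.0, 2.0], 'box_features': [2.0, 3.0]}
```

Swapping the backbone means passing a different extractor object, which is how one meta-architecture can serve every faster_rcnn_*_feature_extractor file.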

But it also imports faster_rcnn_meta_arch,
so let's look at faster_rcnn_meta_arch.py.

faster_rcnn_meta_arch.py

#faster_rcnn_meta_arch.py
from abc import abstractmethod
from functools import partial
import tensorflow as tf

from object_detection.anchor_generators import grid_anchor_generator
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import box_predictor
from object_detection.core import losses
from object_detection.core import model
from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
from object_detection.utils import shape_utils

Here we find model.py, which holds the base class of every detection model.

model.py

#model.py
from abc import ABCMeta
from abc import abstractmethod

from object_detection.core import standard_fields as fields

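Sure enough, it is the base class and imports almost nothing. This is the true ancestor of the detection models, and it is about as bare-bones as it gets.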

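class DetectionModel(object):
  """Abstract base class for detection models."""
  __metaclass__ = ABCMeta

  def __init__(self, num_classes):
    """Constructor.

    Args:
      num_classes: number of classes, not including background.
    """
    self._num_classes = num_classes
    self._groundtruth_lists = {}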

So the base class DetectionModel(object) is defined in this file.
At training time it defines the following flow:
inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
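
To make the training-time contract concrete, here is a minimal Python 3 sketch of an abstract base class with the same shape (provide_groundtruth, then preprocess -> predict -> loss). `ToyDetectionModel`, `ToyModel`, and their arithmetic are hypothetical; the real DetectionModel also declares postprocess and other methods and operates on TensorFlow tensors, not Python lists.

```python
from abc import ABC, abstractmethod


class ToyDetectionModel(ABC):
    """Hypothetical mirror of DetectionModel's training-time contract."""

    def __init__(self, num_classes):
        self._num_classes = num_classes
        self._groundtruth_lists = {}

    @property
    def num_classes(self):
        return self._num_classes

    def provide_groundtruth(self, groundtruth_targets):
        # Groundtruth is stored on the model and read later by loss().
        self._groundtruth_lists['targets'] = groundtruth_targets

    @abstractmethod
    def preprocess(self, inputs):
        pass

    @abstractmethod
    def predict(self, preprocessed_inputs):
        pass

    @abstractmethod
    def loss(self, prediction_dict):
        pass


class ToyModel(ToyDetectionModel):
    def preprocess(self, inputs):
        # Scale pixel values from [0, 255] to [0, 1].
        return [x / 255.0 for x in inputs]

    def predict(self, preprocessed_inputs):
        return {'scores': [2.0 * x for x in preprocessed_inputs]}

    def loss(self, prediction_dict):
        targets = self._groundtruth_lists['targets']
        # Mean squared error between predictions and groundtruth.
        diffs = [(p - t) ** 2 for p, t in zip(prediction_dict['scores'], targets)]
        return {'toy_loss': sum(diffs) / len(diffs)}


model = ToyModel(num_classes=1)
model.provide_groundtruth([1.0, 0.0])
losses = model.loss(model.predict(model.preprocess([255.0, 0.0])))
print(losses)  # {'toy_loss': 0.5}
```

Because the base class marks preprocess, predict, and loss as abstract, instantiating ToyDetectionModel directly raises TypeError; only a subclass that fills in all three can be built.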

1. Dependency chain for train.py:
1. model.py -> faster_rcnn_meta_arch.py -> faster_rcnn_inception_resnet_v2_feature_extractor.py
2. inception_resnet_v2 -> faster_rcnn_inception_resnet_v2_feature_extractor.py -> model_builder.py
3. model_builder.py -> train.py

2. Dependency chain for eval.py:
1. model.py -> faster_rcnn_meta_arch.py -> faster_rcnn_inception_resnet_v2_feature_extractor.py
2. inception_resnet_v2 -> faster_rcnn_inception_resnet_v2_feature_extractor.py -> model_builder.py
3. model_builder.py -> eval.py

[Figure: diagram by haixwang]

trainer.py

import tensorflow as tf  # TF 1.x API (tf.losses.add_loss)

def _create_losses(input_queue, create_model_fn, train_config):
  """Creates loss function for a DetectionModel.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
  """
  # Create a detection model
  detection_model = create_model_fn()

  # Read a batch of inputs with get_inputs() (defined earlier in trainer.py)
  (images, _, groundtruth_boxes_list, groundtruth_classes_list,
   groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs(
       input_queue,
       detection_model.num_classes,
       train_config.merge_multiple_label_boxes)

  # Preprocess (resize and normalize) each image, then concatenate into one batch
  images = [detection_model.preprocess(image) for image in images]
  images = tf.concat(images, 0)
  if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None
  if any(keypoints is None for keypoints in groundtruth_keypoints_list):
    groundtruth_keypoints_list = None

  # Hand the groundtruth labels to the model
  detection_model.provide_groundtruth(groundtruth_boxes_list,
                                      groundtruth_classes_list,
                                      groundtruth_masks_list,
                                      groundtruth_keypoints_list)

  # Run prediction
  prediction_dict = detection_model.predict(images)

  # Compute the losses and register each one with TensorFlow
  losses_dict = detection_model.loss(prediction_dict)
  for loss_tensor in losses_dict.values():
    tf.losses.add_loss(loss_tensor)

_create_losses() does not return its losses; it registers each one with tf.losses.add_loss, and train() then collects them and optimizes them.

  • train()

def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function that creates the input tensor dictionaries.
    create_model_fn: a function that creates a DetectionModel and generates losses.
    train_config: a train_pb2.TrainConfig protobuf (the training configuration).
    master: The location of the tensorflow session to use (for distributed training).
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

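At this point the training flow is more or less covered: trainer.train() is where everything is finally wired together. And I did find the loss after all (see the parameter create_model_fn: a function that creates a DetectionModel and generates losses).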
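
train.py does not build the model itself; it hands trainer.train() factory functions such as create_model_fn. As far as I can tell it builds them with functools.partial (which explains the functools import), roughly `functools.partial(model_builder.build, model_config=model_config, is_training=True)`. A minimal sketch of the pattern, with a hypothetical `build_model` and a plain dict in place of the protobuf config:

```python
import functools


def build_model(model_config, is_training):
    """Hypothetical builder: returns a description instead of a real model."""
    return {'arch': model_config['arch'], 'is_training': is_training}


def train(create_model_fn, num_clones):
    # In this sketch the factory is called once per clone (e.g. one per GPU).
    return [create_model_fn() for _ in range(num_clones)]


model_config = {'arch': 'faster_rcnn'}
create_model_fn = functools.partial(
    build_model, model_config=model_config, is_training=True)

clones = train(create_model_fn, num_clones=2)
print(clones)
# [{'arch': 'faster_rcnn', 'is_training': True}, {'arch': 'faster_rcnn', 'is_training': True}]
```

Passing a zero-argument factory rather than a finished model lets the trainer decide when, and how many times, the model graph gets built.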

How eval.py uses DetectionModel

Overall prediction flow: inputs (images tensor) -> preprocess -> predict -> postprocess -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
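
postprocess is the step that turns raw predictions into the four output tensors, and for Faster R-CNN it includes non-max suppression. The sketch below is a simplified, single-class, NumPy-only stand-in (score threshold plus greedy NMS), not the library's post_processing module; `toy_postprocess`, `iou`, and the default thresholds are assumptions chosen for illustration.

```python
import numpy as np


def iou(box, boxes):
    """IoU between one [ymin, xmin, ymax, xmax] box and an array of boxes."""
    ymin = np.maximum(box[0], boxes[:, 0])
    xmin = np.maximum(box[1], boxes[:, 1])
    ymax = np.minimum(box[2], boxes[:, 2])
    xmax = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(ymax - ymin, 0, None) * np.clip(xmax - xmin, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def toy_postprocess(boxes, scores, score_thresh=0.3, iou_thresh=0.5, max_det=100):
    """Greedy NMS for a single class; returns the four output fields."""
    order = np.argsort(-scores)
    order = order[scores[order] >= score_thresh]
    keep = []
    while order.size > 0 and len(keep) < max_det:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # Drop remaining boxes that overlap the kept box too much.
        order = rest[iou(boxes[best], boxes[rest]) <= iou_thresh]
    keep = np.array(keep, dtype=np.int64)
    return {
        'detection_boxes': boxes[keep],
        'detection_scores': scores[keep],
        'detection_classes': np.zeros(len(keep), dtype=np.int64),  # one class
        'num_detections': len(keep),
    }


boxes = np.array([[0.0, 0.0, 1.0, 1.0],
                  [0.0, 0.0, 0.9, 1.0],
                  [2.0, 2.0, 3.0, 3.0],
                  [5.0, 5.0, 6.0, 6.0]])
scores = np.array([0.9, 0.8, 0.7, 0.1])
out = toy_postprocess(boxes, scores)
print(out['num_detections'], out['detection_scores'])  # 2 [0.9 0.7]
```

Here the second box is suppressed because it overlaps the first with IoU 0.9, and the last box falls below the score threshold.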
eval.py imports the following packages:

# eval.py 
import functools
import os
import tensorflow as tf

import evaluator
from object_detection.builders import input_reader_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import label_map_util

As with train.py, the real user of DetectionModel here is evaluator.

  • _extract_prediction_tensors()

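import tensorflow as tf
from object_detection import eval_util
from object_detection.core import prefetcher
from object_detection.core import standard_fields as fields


def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                ignore_groundtruth=False):
  """Restores the model in a tensorflow session.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    tensor_dict: A tensor dictionary with evaluations.
  """
  # Build the input queue, prefetching up to 500 examples ahead of the model
  input_dict = create_input_dict_fn()
  prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
  input_dict = prefetch_queue.dequeue()
  # Add a batch dimension: [height, width, channels] -> [1, height, width, channels]
  original_image = tf.expand_dims(input_dict[fields.InputDataFields.image], 0)

  # Preprocess the image
  preprocessed_image = model.preprocess(tf.to_float(original_image))

  # Run prediction
  prediction_dict = model.predict(preprocessed_image)

  # Postprocess into boxes, scores, classes and num_detections
  detections = model.postprocess(prediction_dict)

  # Collect the groundtruth labels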
  groundtruth = None
  if not ignore_groundtruth:
    groundtruth = {
        fields.InputDataFields.groundtruth_boxes:
            input_dict[fields.InputDataFields.groundtruth_boxes],
        fields.InputDataFields.groundtruth_classes:
            input_dict[fields.InputDataFields.groundtruth_classes],
        fields.InputDataFields.groundtruth_area:
            input_dict[fields.InputDataFields.groundtruth_area],
        fields.InputDataFields.groundtruth_is_crowd:
            input_dict[fields.InputDataFields.groundtruth_is_crowd],
        fields.InputDataFields.groundtruth_difficult:
            input_dict[fields.InputDataFields.groundtruth_difficult]
    }
    if fields.InputDataFields.groundtruth_group_of in input_dict:
      groundtruth[fields.InputDataFields.groundtruth_group_of] = (
          input_dict[fields.InputDataFields.groundtruth_group_of])
    if fields.DetectionResultFields.detection_masks in detections:
      groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
          input_dict[fields.InputDataFields.groundtruth_instance_masks])

  return eval_util.result_dict_for_single_example(
      original_image,
      input_dict[fields.InputDataFields.source_id],
      detections,
      groundtruth,
      class_agnostic=(
          fields.DetectionResultFields.detection_classes not in detections),
      scale_to_absolute=True)
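
scale_to_absolute=True asks result_dict_for_single_example to convert box coordinates from the normalized [ymin, xmin, ymax, xmax] form the model outputs (values in [0, 1]) into pixel coordinates of original_image. The arithmetic amounts to multiplying by [height, width, height, width]. A NumPy sketch, where `to_absolute` is a hypothetical helper rather than the library's own box_list_ops function:

```python
import numpy as np


def to_absolute(normalized_boxes, height, width):
    """Scale [ymin, xmin, ymax, xmax] boxes from [0, 1] to pixel units."""
    scale = np.array([height, width, height, width], dtype=np.float64)
    return np.asarray(normalized_boxes, dtype=np.float64) * scale


# original_image has shape [1, height, width, 3] after tf.expand_dims(..., 0).
image_shape = (1, 480, 640, 3)
boxes = np.array([[0.25, 0.5, 0.75, 1.0]])
print(to_absolute(boxes, image_shape[1], image_shape[2]))
# [[120. 320. 360. 640.]]
```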

Read on for the details of the detection models themselves!!!
