TensorFlow Object Detection API source code analysis: how the data is read

 

 

Core code:

object_detection/train.py

The TFRecord file locations are usually set in samples/configs/*.config.

Example: samples/configs/ssd_mobilenet_v1_coco.config

train_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}

utils/config_util.py reads each configuration section from the config file

def create_configs_from_pipeline_proto(pipeline_config):
  """Creates a configs dictionary from pipeline_pb2.TrainEvalPipelineConfig.

  Args:
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto object.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_configs`. Value are
      the corresponding config objects or list of config objects (only for
      eval_input_configs).
  """
  configs = {}
  configs["model"] = pipeline_config.model
  configs["train_config"] = pipeline_config.train_config
  configs["train_input_config"] = pipeline_config.train_input_reader
  configs["eval_config"] = pipeline_config.eval_config
  configs["eval_input_configs"] = pipeline_config.eval_input_reader
  # Keeps eval_input_config only for backwards compatibility. All clients should
  # read eval_input_configs instead.
  if configs["eval_input_configs"]:
    configs["eval_input_config"] = configs["eval_input_configs"][0]
  if pipeline_config.HasField("graph_rewriter"):
    configs["graph_rewriter_config"] = pipeline_config.graph_rewriter

  return configs
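
The dictionary layout, including the backwards-compatible `eval_input_config` alias, can be tried out without TensorFlow. The sketch below is only an illustration: `create_configs_sketch` and `fake_pipeline` are hypothetical names, a `SimpleNamespace` stands in for the parsed proto, and the `graph_rewriter` branch is left out.

```python
from types import SimpleNamespace


def create_configs_sketch(pipeline_config):
    """Mimics the dictionary built by create_configs_from_pipeline_proto."""
    configs = {
        "model": pipeline_config.model,
        "train_config": pipeline_config.train_config,
        "train_input_config": pipeline_config.train_input_reader,
        "eval_config": pipeline_config.eval_config,
        "eval_input_configs": pipeline_config.eval_input_reader,
    }
    # Backwards-compatible alias: the first eval reader, if there is one.
    if configs["eval_input_configs"]:
        configs["eval_input_config"] = configs["eval_input_configs"][0]
    return configs


# Hypothetical stand-in for a parsed TrainEvalPipelineConfig proto.
fake_pipeline = SimpleNamespace(
    model="ssd",
    train_config="train settings",
    train_input_reader="train reader",
    eval_config="eval settings",
    eval_input_reader=["eval reader A", "eval reader B"],
)
configs = create_configs_sketch(fake_pipeline)
print(configs["train_input_config"])
print(configs["eval_input_config"])
```

train.py only needs the `train_input_config` entry, which is the `train_input_reader` block of the config file.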

train.py reads the data

  input_config = configs['train_input_config']
  def get_next(config):
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()
  # dataset_util.make_initializable_iterator returns a tf.data.Iterator.
  create_input_dict_fn = functools.partial(get_next, input_config)
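
The point of `functools.partial` here is that nothing is built until the trainer calls `create_input_dict_fn()`. A minimal sketch of that deferral, assuming a plain dictionary as a hypothetical stand-in for the input config and a dummy `get_next` that does not touch TensorFlow:

```python
import functools

calls = []


def get_next(config):
    # Stand-in for building the dataset and calling iterator.get_next().
    calls.append(config["input_path"])
    return {"source": config["input_path"]}


input_config = {"input_path": "train.record"}  # hypothetical config
create_input_dict_fn = functools.partial(get_next, input_config)

# Binding the argument has not called get_next yet.
calls_before = len(calls)

# The trainer decides when to call it, and needs no arguments to do so.
tensor_dict = create_input_dict_fn()
print(calls_before, tensor_dict)
```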

builders/dataset_builder.py

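Overall flow

  1. Initialize the fields to be decoded and the handler that decodes each field.
  2. Call tf.data.TFRecordDataset to read records from config.input_path, call process_fn to decode them, and prefetch input_reader_config.prefetch_size records.
  3. If batch_size is set, apply tf.contrib.data.padded_batch_and_drop_remainder to the dataset; a trailing group of records smaller than batch_size is dropped.
  4. Return the dataset, which train.py wraps in an iterator.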
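
The steps above can be sketched in plain Python to show the order of operations. This is not the API's implementation: `decode` and `batch_and_drop_remainder` are hypothetical stand-ins for the decoder and for the padded-batch transformation, and padding and prefetching are left out.

```python
import itertools


def decode(record):
    # Stand-in for TfExampleDecoder.decode: one record -> dict of fields.
    return {"image": record.upper()}


def batch_and_drop_remainder(items, batch_size):
    """Yields full batches only; a trailing partial batch is discarded."""
    iterator = iter(items)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        if len(batch) < batch_size:
            return
        yield batch


records = ["a", "b", "c", "d", "e"]      # step 2: records read from file
decoded = (decode(r) for r in records)   # step 2: decode each record
batches = list(batch_and_drop_remainder(decoded, batch_size=2))  # step 3
print(len(batches))
```

With five records and a batch size of two, the fifth record never reaches the model in that pass.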

core/standard_fields.py

class InputDataFields(object):
  """Names for the input tensors.

  Holds the standard data field names to use for identifying input tensors. This
  should be used by the decoder to identify keys for the returned tensor_dict
  containing input tensors. And it should be used by the model to identify the
  tensors it needs.

  Attributes:
    image: image.
    image_additional_channels: additional channels.
    original_image: image in the original input size.
    key: unique key corresponding to image.
    source_id: source of the original image.
    filename: original filename of the dataset (without common path).
    groundtruth_image_classes: image-level class labels.
    groundtruth_boxes: coordinates of the ground truth boxes in the image.
    groundtruth_classes: box-level class labels.
    groundtruth_label_types: box-level label types (e.g. explicit negative).
    groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead]
      is the groundtruth a single object or a crowd.
    groundtruth_area: area of a groundtruth segment.
    groundtruth_difficult: is a `difficult` object
    groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the
      same class, forming a connected group, where instances are heavily
      occluding each other.
    proposal_boxes: coordinates of object proposal boxes.
    proposal_objectness: objectness score of each proposal.
    groundtruth_instance_masks: ground truth instance masks.
    groundtruth_instance_boundaries: ground truth instance boundaries.
    groundtruth_instance_classes: instance mask-level class labels.
    groundtruth_keypoints: ground truth keypoints.
    groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
    groundtruth_label_scores: groundtruth label scores.
    groundtruth_weights: groundtruth weight factor for bounding boxes.
    num_groundtruth_boxes: number of groundtruth boxes.
    true_image_shapes: true shapes of images in the resized images, as resized
      images can be padded with zeros.
    verified_labels: list of human-verified image-level labels (note, that a
      label can be verified both as positive and negative).
    multiclass_scores: the label score per class for each box.
  """
  image = 'image'
  image_additional_channels = 'image_additional_channels'
  original_image = 'original_image'
  key = 'key'
  source_id = 'source_id'
  filename = 'filename'
  groundtruth_image_classes = 'groundtruth_image_classes'
  groundtruth_boxes = 'groundtruth_boxes'
  groundtruth_classes = 'groundtruth_classes'
  groundtruth_label_types = 'groundtruth_label_types'
  groundtruth_is_crowd = 'groundtruth_is_crowd'
  groundtruth_area = 'groundtruth_area'
  groundtruth_difficult = 'groundtruth_difficult'
  groundtruth_group_of = 'groundtruth_group_of'
  proposal_boxes = 'proposal_boxes'
  proposal_objectness = 'proposal_objectness'
  groundtruth_instance_masks = 'groundtruth_instance_masks'
  groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
  groundtruth_instance_classes = 'groundtruth_instance_classes'
  groundtruth_keypoints = 'groundtruth_keypoints'
  groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
  groundtruth_label_scores = 'groundtruth_label_scores'
  groundtruth_weights = 'groundtruth_weights'
  num_groundtruth_boxes = 'num_groundtruth_boxes'
  true_image_shape = 'true_image_shape'
  verified_labels = 'verified_labels'
  multiclass_scores = 'multiclass_scores'

object_detection/builders/dataset_builder.py 

# Returns the padding shape for each key in dataset.output_shapes
def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
                        spatial_image_shape=None):
  """Returns shapes to pad dataset tensors to before batching.

  Args:
    dataset: tf.data.Dataset object.
    max_num_boxes: Max number of groundtruth boxes needed to computes shapes for
      padding.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image.

  Returns:
    A dictionary keyed by fields.InputDataFields containing padding shapes for
    tensors in the dataset.

  Raises:
    ValueError: If groundtruth classes is neither rank 1 nor rank 2.
  """

  if not spatial_image_shape or spatial_image_shape == [-1, -1]:
    height, width = None, None
  else:
    height, width = spatial_image_shape  # pylint: disable=unpacking-non-sequence

  padding_shapes = {
      fields.InputDataFields.image: [height, width, 3],
      fields.InputDataFields.source_id: [],
      fields.InputDataFields.filename: [],
      fields.InputDataFields.key: [],
      fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
      fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
      fields.InputDataFields.groundtruth_instance_masks: [max_num_boxes, height,
                                                          width],
      fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
      fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
      fields.InputDataFields.groundtruth_area: [max_num_boxes],
      fields.InputDataFields.groundtruth_weights: [max_num_boxes],
      fields.InputDataFields.num_groundtruth_boxes: [],
      fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
      fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
      fields.InputDataFields.true_image_shape: [3],
      fields.InputDataFields.multiclass_scores: [
          max_num_boxes, num_classes + 1 if num_classes is not None else None],
  }
  # Determine whether groundtruth_classes are integers or one-hot encodings, and
  # apply batching appropriately.
  classes_shape = dataset.output_shapes[
      fields.InputDataFields.groundtruth_classes]
  if len(classes_shape) == 1:  # Class integers.
    padding_shapes[fields.InputDataFields.groundtruth_classes] = [max_num_boxes]
  elif len(classes_shape) == 2:  # One-hot or k-hot encoding.
    padding_shapes[fields.InputDataFields.groundtruth_classes] = [
        max_num_boxes, num_classes]
  else:
    raise ValueError('Groundtruth classes must be a rank 1 tensor (classes) or '
                     'rank 2 tensor (one-hot encodings)')

  if fields.InputDataFields.original_image in dataset.output_shapes:
    padding_shapes[fields.InputDataFields.original_image] = [None, None, 3]
  if fields.InputDataFields.groundtruth_keypoints in dataset.output_shapes:
    tensor_shape = dataset.output_shapes[fields.InputDataFields.
                                         groundtruth_keypoints]
    padding_shape = [max_num_boxes, tensor_shape[1].value,
                     tensor_shape[2].value]
    padding_shapes[fields.InputDataFields.groundtruth_keypoints] = padding_shape
  if (fields.InputDataFields.groundtruth_keypoint_visibilities
      in dataset.output_shapes):
    tensor_shape = dataset.output_shapes[fields.InputDataFields.
                                         groundtruth_keypoint_visibilities]
    padding_shape = [max_num_boxes, tensor_shape[1].value]
    padding_shapes[fields.InputDataFields.
                   groundtruth_keypoint_visibilities] = padding_shape
  return {tensor_key: padding_shapes[tensor_key]
          for tensor_key, _ in dataset.output_shapes.items()}

def build(input_reader_config,
          transform_input_data_fn=None,
          batch_size=None,
          max_num_boxes=None,
          num_classes=None,
          spatial_image_shape=None,
          num_additional_channels=0):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If None, batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.
    num_additional_channels: Number of additional channels to use in the input.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  '''
  train_input_reader: {
  tf_record_input_reader {
    input_path: "data/wider_udlr.record"
  }
  label_map_path: "data/udlr_face.pbtxt"
}
  '''
  # input_reader_config must be an input_reader_pb2.InputReader object.
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    # input_path: path(s) of the TFRecord file(s) to read.
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      # Path to the label map file.
      label_map_proto_file = input_reader_config.label_map_path
    # Initialize the fields to decode and the handler that decodes each field.
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        use_display_name=input_reader_config.use_display_name,
        num_additional_channels=num_additional_channels)

    # Call tf.data.TFRecordDataset to read records from config.input_path, call
    # process_fn to decode them, and prefetch input_reader_config.prefetch_size
    # records.
    def process_fn(value):
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    dataset = dataset_util.read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        process_fn, config.input_path[:], input_reader_config)

    if batch_size:
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')
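
Padding is needed because images in one batch carry different numbers of boxes, while a batched tensor needs a single shape. A minimal pure-Python sketch of that idea follows. `pad_boxes` is a hypothetical helper, not part of the API, and it assumes zero is the padding value; the real pipeline pads tensors through tf.contrib.data.padded_batch_and_drop_remainder using the shapes from _get_padding_shapes.

```python
def pad_boxes(boxes, max_num_boxes):
    """Pads a list of [ymin, xmin, ymax, xmax] boxes with all-zero rows."""
    if len(boxes) > max_num_boxes:
        raise ValueError("more boxes than max_num_boxes")
    missing = max_num_boxes - len(boxes)
    return boxes + [[0.0, 0.0, 0.0, 0.0] for _ in range(missing)]


# Two hypothetical images with one and two groundtruth boxes.
image_a = [[0.1, 0.1, 0.5, 0.5]]
image_b = [[0.2, 0.2, 0.4, 0.4], [0.3, 0.3, 0.9, 0.9]]

# After padding, both entries have shape [max_num_boxes, 4].
batch = [pad_boxes(b, max_num_boxes=3) for b in (image_a, image_b)]
print([len(b) for b in batch])
```

Because padded rows look like real boxes, the number of real boxes has to be tracked separately, which is the role of the num_groundtruth_boxes field.
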

utils/dataset_util.py

def make_initializable_iterator(dataset):
  """Creates an iterator, and initializes tables.

  This is useful in cases where make_one_shot_iterator wouldn't work because
  the graph contains a hash table that needs to be initialized.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.data.Iterator`.
  """
  iterator = dataset.make_initializable_iterator()
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  return iterator

# Calls file_read_func to read records from input_files.
# Calls decode_func to decode the records that were read.
# Prefetches config.prefetch_size records.
def read_dataset(file_read_func, decode_func, input_files, config):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf.data.Dataset.interleave, to read
      every individual file into a tf.data.Dataset.
    decode_func: Function to apply to all records.
    input_files: A list of file paths to read.
    config: A input_reader_builder.InputReader object.

  Returns:
    A tf.data.Dataset based on config.
  """
  # Shard, shuffle, and read files.
  filenames = tf.gfile.Glob(input_files)
  num_readers = config.num_readers
  if num_readers > len(filenames):
    num_readers = len(filenames)
    tf.logging.warning('num_readers has been reduced to %d to match input file '
                       'shards.' % num_readers)
  filename_dataset = tf.data.Dataset.from_tensor_slices(tf.unstack(filenames))
  if config.shuffle:
    filename_dataset = filename_dataset.shuffle(
        config.filenames_shuffle_buffer_size)
  elif num_readers > 1:
    tf.logging.warning('`shuffle` is false, but the input data stream is '
                       'still slightly shuffled since `num_readers` > 1.')

  filename_dataset = filename_dataset.repeat(config.num_epochs or None)

  records_dataset = filename_dataset.apply(
      tf.contrib.data.parallel_interleave(
          file_read_func,
          cycle_length=num_readers,
          block_length=config.read_block_length,
          sloppy=config.shuffle))
  if config.shuffle:
    records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
  tensor_dataset = records_dataset.map(
      decode_func, num_parallel_calls=config.num_parallel_map_calls)
  return tensor_dataset.prefetch(config.prefetch_size)
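
To see why the warning says the stream is "still slightly shuffled" when `num_readers` > 1, it helps to simulate the interleaving. The sketch below is a simplified, deterministic approximation in plain Python; `interleave` and `shards` are hypothetical names, and the exact ordering produced by tf.contrib.data.parallel_interleave (especially with sloppy=True) may differ.

```python
import itertools


def interleave(files, num_readers, block_length):
    """Round-robins over open files, taking block_length records per turn."""
    # Mirrors the clamp in read_dataset: never more readers than files.
    cycle_length = min(num_readers, len(files))
    pending = iter(files)
    open_readers = [iter(f) for f in itertools.islice(pending, cycle_length)]
    while open_readers:
        still_open = []
        for reader in open_readers:
            block = list(itertools.islice(reader, block_length))
            yield from block
            if len(block) == block_length:
                still_open.append(reader)
            else:
                # This file is exhausted; open the next unread file, if any.
                replacement = next(pending, None)
                if replacement is not None:
                    still_open.append(iter(replacement))
        open_readers = still_open


# Hypothetical shards, each a list of record names.
shards = [["a1", "a2", "a3"], ["b1", "b2", "b3"], ["c1"]]
order = list(interleave(shards, num_readers=2, block_length=2))
print(order)
```

Records from one shard keep their relative order, but records from different shards are mixed, so the output order differs from reading the shards one after another.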

protos/input_reader.proto

syntax = "proto2";

package object_detection.protos;

// Configuration proto for defining input readers that generate Object Detection
// Examples from input sources. Input readers are expected to generate a
// dictionary of tensors, with the following fields populated:
//
// 'image': an [image_height, image_width, channels] image tensor that detection
//    will be run on.
// 'groundtruth_classes': a [num_boxes] int32 tensor storing the class
//    labels of detected boxes in the image.
// 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of
//    detected boxes in the image.
// 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height,
//    image_width] float tensor storing binary mask of the objects in boxes.

// Instance mask format. Note that PNG masks are much more space efficient.
enum InstanceMaskType {
  DEFAULT = 0;          // Default implementation, currently NUMERICAL_MASKS
  NUMERICAL_MASKS = 1;  // [num_masks, H, W] float32 binary masks.
  PNG_MASKS = 2;        // Encoded PNG masks.
}

message InputReader {
  // Path to StringIntLabelMap pbtxt file specifying the mapping from string
  // labels to integer ids.
  optional string label_map_path = 1 [default=""];

  // Whether data should be processed in the order they are read in, or
  // shuffled randomly.
  optional bool shuffle = 2 [default=true];

  // Buffer size to be used when shuffling.
  optional uint32 shuffle_buffer_size = 11 [default = 2048];

  // Buffer size to be used when shuffling file names.
  optional uint32 filenames_shuffle_buffer_size = 12 [default = 100];

  // Maximum number of records to keep in reader queue.
  optional uint32 queue_capacity = 3 [default=2000];

  // Minimum number of records to keep in reader queue. A large value is needed
  // to generate a good random shuffle.
  optional uint32 min_after_dequeue = 4 [default=1000];

  // The number of times a data source is read. If set to zero, the data source
  // will be reused indefinitely.
  optional uint32 num_epochs = 5 [default=0];

  // Number of reader instances to create.
  optional uint32 num_readers = 6 [default=32];

  // Number of records to read from each reader at once.
  optional uint32 read_block_length = 15 [default=32];

  // Number of decoded records to prefetch before batching.
  optional uint32 prefetch_size = 13 [default = 512];

  // Number of parallel decode ops to apply.
  optional uint32 num_parallel_map_calls = 14 [default = 64];

  // Number of groundtruth keypoints per object.
  optional uint32 num_keypoints = 16 [default = 0];

  // Whether to load groundtruth instance masks.
  optional bool load_instance_masks = 7 [default = false];

  // Type of instance mask.
  optional InstanceMaskType mask_type = 10 [default = NUMERICAL_MASKS];

  // Whether to use the display name when decoding examples. This is only used
  // when mapping class text strings to integers.
  optional bool use_display_name = 17 [default = false];

  oneof input_reader {
    TFRecordInputReader tf_record_input_reader = 8;
    ExternalInputReader external_input_reader = 9;
  }
}

// An input reader that reads TF Example protos from local TFRecord files.
message TFRecordInputReader {
  // Path(s) to `TFRecordFile`s.
  repeated string input_path = 1;
}

// An externally defined input reader. Users may define an extension to this
// proto to interface their own input readers.
message ExternalInputReader {
  extensions 1 to 999;
}
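
One default worth noting is `num_epochs = 0`, which means "repeat indefinitely": read_dataset passes `config.num_epochs or None` to `repeat`, so zero becomes `None`. A minimal sketch of that idiom in plain Python, where `repeat` is a hypothetical stand-in for `tf.data.Dataset.repeat` operating on a list:

```python
import itertools


def repeat(items, num_epochs):
    """Repeats items num_epochs times, or forever when num_epochs is 0."""
    count = num_epochs or None  # 0 is falsy, so it turns into None
    if count is None:
        return itertools.cycle(items)
    return itertools.chain.from_iterable(itertools.repeat(items, count))


two_epochs = list(repeat(["x", "y"], num_epochs=2))
# With the default of 0 the stream never ends, so take only a prefix.
endless_prefix = list(itertools.islice(repeat(["x", "y"], num_epochs=0), 5))
print(two_epochs, endless_prefix)
```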

data_decoders/tf_example_decoder.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tensorflow Example proto decoder for object detection.

A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import tensorflow as tf

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.protos import input_reader_pb2
from object_detection.utils import label_map_util

slim_example_decoder = tf.contrib.slim.tfexample_decoder


# TODO(lzc): keep LookupTensor and BackupHandler in sync with
# tf.contrib.slim.tfexample_decoder version.
class LookupTensor(slim_example_decoder.Tensor):
  """An ItemHandler that returns a parsed Tensor, the result of a lookup."""

  def __init__(self,
               tensor_key,
               table,
               shape_keys=None,
               shape=None,
               default_value=''):
    """Initializes the LookupTensor handler.

    Simply calls a vocabulary (most often, a label mapping) lookup.

    Args:
      tensor_key: the name of the `TFExample` feature to read the tensor from.
      table: A tf.lookup table.
      shape_keys: Optional name or list of names of the TF-Example feature in
        which the tensor shape is stored. If a list, then each corresponds to
        one dimension of the shape.
      shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
        reshaped accordingly.
      default_value: The value used when the `tensor_key` is not found in a
        particular `TFExample`.

    Raises:
      ValueError: if both `shape_keys` and `shape` are specified.
    """
    self._table = table
    super(LookupTensor, self).__init__(tensor_key, shape_keys, shape,
                                       default_value)

  def tensors_to_item(self, keys_to_tensors):
    unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
    return self._table.lookup(unmapped_tensor)


class BackupHandler(slim_example_decoder.ItemHandler):
  """An ItemHandler that tries two ItemHandlers in order."""

  def __init__(self, handler, backup):
    """Initializes the BackupHandler handler.

    If the first Handler's tensors_to_item returns a Tensor with no elements,
    the second Handler is used.

    Args:
      handler: The primary ItemHandler.
      backup: The backup ItemHandler.

    Raises:
      ValueError: if either is not an ItemHandler.
    """
    if not isinstance(handler, slim_example_decoder.ItemHandler):
      raise ValueError('Primary handler is of type %s instead of ItemHandler' %
                       type(handler))
    if not isinstance(backup, slim_example_decoder.ItemHandler):
      raise ValueError(
          'Backup handler is of type %s instead of ItemHandler' % type(backup))
    self._handler = handler
    self._backup = backup
    super(BackupHandler, self).__init__(handler.keys + backup.keys)

  def tensors_to_item(self, keys_to_tensors):
    item = self._handler.tensors_to_item(keys_to_tensors)
    return control_flow_ops.cond(
        pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0),
        true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
        false_fn=lambda: item)


class TfExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Example proto decoder."""
  # __init__ specifies which fields to decode and the handler for each field.
  def __init__(self,
               load_instance_masks=False,
               instance_mask_type=input_reader_pb2.NUMERICAL_MASKS,
               label_map_proto_file=None,
               use_display_name=False,
               dct_method='',
               num_keypoints=0,
               num_additional_channels=0):
    """Constructor sets keys_to_features and items_to_handlers.

    Args:
      load_instance_masks: whether or not to load and handle instance masks.
      instance_mask_type: type of instance masks. Options are provided in
        input_reader.proto. This is only used if `load_instance_masks` is True.
      label_map_proto_file: a file path to a
        object_detection.protos.StringIntLabelMap proto. If provided, then the
        mapped IDs of 'image/object/class/text' will take precedence over the
        existing 'image/object/class/label' ID.  Also, if provided, it is
        assumed that 'image/object/class/text' will be in the data.
      use_display_name: whether or not to use the `display_name` for label
        mapping (instead of `name`).  Only used if label_map_proto_file is
        provided.
      dct_method: An optional string. Defaults to ''. It only takes
        effect when image format is jpeg, used to specify a hint about the
        algorithm used for jpeg decompression. Currently valid values
        are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
        example, the jpeg library does not have that specific option.
      num_keypoints: the number of keypoints per object.
      num_additional_channels: how many additional channels to use.

    Raises:
      ValueError: If `instance_mask_type` option is not one of
        input_reader_pb2.DEFAULT, input_reader_pb2.NUMERICAL_MASKS, or
        input_reader_pb2.PNG_MASKS.
    """
    self.keys_to_features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/key/sha256':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/source_id':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        # Object boxes and classes.
        'image/object/bbox/xmin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax':
            tf.VarLenFeature(tf.float32),
        'image/object/class/label':
            tf.VarLenFeature(tf.int64),
        'image/object/class/text':
            tf.VarLenFeature(tf.string),
        'image/object/area':
            tf.VarLenFeature(tf.float32),
        'image/object/is_crowd':
            tf.VarLenFeature(tf.int64),
        'image/object/difficult':
            tf.VarLenFeature(tf.int64),
        'image/object/group_of':
            tf.VarLenFeature(tf.int64),
        'image/object/weight':
            tf.VarLenFeature(tf.float32),
    }
    # We are checking `dct_method` instead of passing it directly in order to
    # ensure TF version 1.6 compatibility.
    # For JPEG images, dct_method can speed up decoding.
    if dct_method:
      image = slim_example_decoder.Image(
          image_key='image/encoded',
          format_key='image/format',
          channels=3,
          dct_method=dct_method)
      additional_channel_image = slim_example_decoder.Image(
          image_key='image/additional_channels/encoded',
          format_key='image/format',
          channels=1,
          repeated=True,
          dct_method=dct_method)
    else:
      # The Image handler is what decodes the picture:
      # if image/format is raw or RAW, it calls parsing_ops.decode_raw;
      # if image/format is jpeg, it calls image_ops.decode_jpeg;
      # otherwise it calls image_ops.decode_image.
      image = slim_example_decoder.Image(
          image_key='image/encoded', format_key='image/format', channels=3)
      additional_channel_image = slim_example_decoder.Image(
          image_key='image/additional_channels/encoded',
          format_key='image/format',
          channels=1,
          repeated=True)
    self.items_to_handlers = {
        fields.InputDataFields.image:
            image,
        fields.InputDataFields.source_id: (
            slim_example_decoder.Tensor('image/source_id')),
        fields.InputDataFields.key: (
            slim_example_decoder.Tensor('image/key/sha256')),
        fields.InputDataFields.filename: (
            slim_example_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        fields.InputDataFields.groundtruth_boxes: (
            slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
                                             'image/object/bbox/')),
        fields.InputDataFields.groundtruth_area:
            slim_example_decoder.Tensor('image/object/area'),
        fields.InputDataFields.groundtruth_is_crowd: (
            slim_example_decoder.Tensor('image/object/is_crowd')),
        fields.InputDataFields.groundtruth_difficult: (
            slim_example_decoder.Tensor('image/object/difficult')),
        fields.InputDataFields.groundtruth_group_of: (
            slim_example_decoder.Tensor('image/object/group_of')),
        fields.InputDataFields.groundtruth_weights: (
            slim_example_decoder.Tensor('image/object/weight')),
    }
    if num_additional_channels > 0:
      self.keys_to_features[
          'image/additional_channels/encoded'] = tf.FixedLenFeature(
              (num_additional_channels,), tf.string)
      self.items_to_handlers[
          fields.InputDataFields.
          image_additional_channels] = additional_channel_image
    self._num_keypoints = num_keypoints
    if num_keypoints > 0:
      self.keys_to_features['image/object/keypoint/x'] = (
          tf.VarLenFeature(tf.float32))
      self.keys_to_features['image/object/keypoint/y'] = (
          tf.VarLenFeature(tf.float32))
      self.items_to_handlers[fields.InputDataFields.groundtruth_keypoints] = (
          slim_example_decoder.ItemHandlerCallback(
              ['image/object/keypoint/y', 'image/object/keypoint/x'],
              self._reshape_keypoints))
    if load_instance_masks:
      if instance_mask_type in (input_reader_pb2.DEFAULT,
                                input_reader_pb2.NUMERICAL_MASKS):
        self.keys_to_features['image/object/mask'] = (
            tf.VarLenFeature(tf.float32))
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._reshape_instance_masks))
      elif instance_mask_type == input_reader_pb2.PNG_MASKS:
        self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string)
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._decode_png_instance_masks))
      else:
        raise ValueError('Did not recognize the `instance_mask_type` option.')
    if label_map_proto_file:
      label_map = label_map_util.get_label_map_dict(label_map_proto_file,
                                                    use_display_name)
      # We use a default_value of -1, but we expect all labels to be contained
      # in the label map.
      table = tf.contrib.lookup.HashTable(
          initializer=tf.contrib.lookup.KeyValueTensorInitializer(
              keys=tf.constant(list(label_map.keys())),
              values=tf.constant(list(label_map.values()), dtype=tf.int64)),
          default_value=-1)
      # If the label_map_proto is provided, try to use it in conjunction with
      # the class text, and fall back to a materialized ID.
      # TODO(lzc): note that here we are using BackupHandler defined in this
      # file(which is branching slim_example_decoder.BackupHandler). Need to
      # switch back to slim_example_decoder.BackupHandler once tf 1.5 becomes
      # more popular.
      label_handler = BackupHandler(
          LookupTensor('image/object/class/text', table, default_value=''),
          slim_example_decoder.Tensor('image/object/class/label'))
    else:
      label_handler = slim_example_decoder.Tensor('image/object/class/label')
    self.items_to_handlers[
        fields.InputDataFields.groundtruth_classes] = label_handler

  # decode calls parsing_ops.parse_single_example to parse the requested fields,
  # then calls each field's handler to decode it.
  def decode(self, tf_example_string_tensor):
    """Decodes serialized tensorflow example and returns a tensor dictionary.

    Args:
      tf_example_string_tensor: a string tensor holding a serialized tensorflow
        example proto.

    Returns:
      A dictionary of the following tensors.
      fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, 3]
        containing image.
      fields.InputDataFields.source_id - string tensor containing original
        image id.
      fields.InputDataFields.key - string tensor with unique sha256 hash key.
      fields.InputDataFields.filename - string tensor with original dataset
        filename.
      fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
        [None, 4] containing box corners.
      fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
        [None] containing classes for the boxes.
      fields.InputDataFields.groundtruth_weights - 1D float32 tensor of
        shape [None] indicating the weights of groundtruth boxes.
      fields.InputDataFields.num_groundtruth_boxes - int32 scalar indicating
        the number of groundtruth_boxes.
      fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
        [None] containing object mask area in pixel squared.
      fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
        [None] indicating if the boxes enclose a crowd.

    Optional:
      fields.InputDataFields.image_additional_channels - 3D uint8 tensor of
        shape [None, None, num_additional_channels]. 1st dim is height; 2nd dim
        is width; 3rd dim is the number of additional channels.
      fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
        [None] indicating if the boxes represent `difficult` instances.
      fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
        [None] indicating if the boxes represent `group_of` instances.
      fields.InputDataFields.groundtruth_keypoints - 3D float32 tensor of
        shape [None, None, 2] containing keypoints, where the coordinates of
        the keypoints are ordered (y, x).
      fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
        shape [None, None, None] containing instance masks.
    """
    # Reshape the input to a scalar string, then create a TFExampleDecoder from
    # the features to parse and the handler for each output item.
    serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
                                                    self.items_to_handlers)
    keys = decoder.list_items()
    # decoder.decode calls parsing_ops.parse_single_example on the serialized
    # string tensor; `keys` lists the items to decode.
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    is_crowd = fields.InputDataFields.groundtruth_is_crowd
    tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool)
    tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
    tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
        tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]

    if fields.InputDataFields.image_additional_channels in tensor_dict:
      channels = tensor_dict[fields.InputDataFields.image_additional_channels]
      channels = tf.squeeze(channels, axis=3)
      channels = tf.transpose(channels, perm=[1, 2, 0])
      tensor_dict[fields.InputDataFields.image_additional_channels] = channels

    def default_groundtruth_weights():
      return tf.ones(
          [tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]],
          dtype=tf.float32)

    tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond(
        tf.greater(
            tf.shape(
                tensor_dict[fields.InputDataFields.groundtruth_weights])[0],
            0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
        default_groundtruth_weights)
    return tensor_dict

  def _reshape_keypoints(self, keys_to_tensors):
    """Reshape keypoints.

    The keypoints are reshaped to [num_instances, num_keypoints, 2].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, num_keypoints, 2] holding
        keypoint coordinates ordered (y, x).
    """
    y = keys_to_tensors['image/object/keypoint/y']
    if isinstance(y, tf.SparseTensor):
      y = tf.sparse_tensor_to_dense(y)
    y = tf.expand_dims(y, 1)
    x = keys_to_tensors['image/object/keypoint/x']
    if isinstance(x, tf.SparseTensor):
      x = tf.sparse_tensor_to_dense(x)
    x = tf.expand_dims(x, 1)
    keypoints = tf.concat([y, x], 1)
    keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
    return keypoints

  def _reshape_instance_masks(self, keys_to_tensors):
    """Reshape instance segmentation masks.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
    masks = keys_to_tensors['image/object/mask']
    if isinstance(masks, tf.SparseTensor):
      masks = tf.sparse_tensor_to_dense(masks)
    masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
    return tf.cast(masks, tf.float32)

  def _decode_png_instance_masks(self, keys_to_tensors):
    """Decode PNG instance segmentation masks and stack into dense tensor.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """

    def decode_png_mask(image_buffer):
      image = tf.squeeze(
          tf.image.decode_image(image_buffer, channels=1), axis=2)
      image.set_shape([None, None])
      image = tf.to_float(tf.greater(image, 0))
      return image

    png_masks = keys_to_tensors['image/object/mask']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    if isinstance(png_masks, tf.SparseTensor):
      png_masks = tf.sparse_tensor_to_dense(png_masks, default_value='')
    return tf.cond(
        tf.greater(tf.size(png_masks), 0),
        lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
        lambda: tf.zeros(tf.to_int32(tf.stack([0, height, width]))))

Reference: https://blog.csdn.net/wenxueliu/article/details/80727563
