TensorFlow Object Detection: How to Preprocess Data Elegantly

Core code

This post walks through the data preprocessing pipeline in models/research/object_detection.

samples/configs/faster_rcnn_inception_resnet_v2_atrous_coco.config

train_config: {
 data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

trainer.py

def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
                       batch_queue_capacity, num_batch_queue_threads,
                       prefetch_queue_capacity, data_augmentation_options):
  """Sets up reader, prefetcher and returns input queue.

  Args:
    batch_size_per_clone: batch size to use per clone.
    create_tensor_dict_fn: function to create tensor dictionary.
    batch_queue_capacity: maximum number of elements to store within a queue.
    num_batch_queue_threads: number of threads to use for batching.
    prefetch_queue_capacity: maximum capacity of the queue used to prefetch
                             assembled batches.
    data_augmentation_options: a list of tuples, where each tuple contains a
      data augmentation function and a dictionary containing arguments and their
      values (see preprocessor.py).

  Returns:
    input queue: a batcher.BatchQueue object holding enqueued tensor_dicts
      (which hold images, boxes and targets).  To get a batch of tensor_dicts,
      call input_queue.Dequeue().
  """
  # Read one example from the dataset.
  tensor_dict = create_tensor_dict_fn()

  # Expand [height, width, 3] to [1, height, width, 3].
  tensor_dict[fields.InputDataFields.image] = tf.expand_dims(
      tensor_dict[fields.InputDataFields.image], 0)

  # Cast the [1, height, width, 3] image to float32.
  images = tensor_dict[fields.InputDataFields.image]
  float_images = tf.to_float(images)
  tensor_dict[fields.InputDataFields.image] = float_images

  # Whether the input contains instance masks.
  include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
                            in tensor_dict)
  # Whether the input contains keypoints.
  include_keypoints = (fields.InputDataFields.groundtruth_keypoints
                       in tensor_dict)
  # Whether the input contains multiclass scores.
  include_multiclass_scores = (fields.InputDataFields.multiclass_scores
                               in tensor_dict)

  # Apply the configured data augmentation steps.
  if data_augmentation_options:
    tensor_dict = preprocessor.preprocess(
        tensor_dict, data_augmentation_options,
        func_arg_map=preprocessor.get_default_func_arg_map(
            include_multiclass_scores=include_multiclass_scores,
            include_instance_masks=include_instance_masks,
            include_keypoints=include_keypoints))

  # Create two queues:
  # Queue 1: num_batch_queue_threads threads each read one example at a time
  # and enqueue it; one consumer dequeues batch_size examples at a time.
  # Queue 2: batches dequeued from queue 1 are enqueued into queue 2, whose
  # capacity is prefetch_queue_capacity. Calling dequeue() reads one batch of
  # batch_size examples from queue 2.
  input_queue = batcher.BatchQueue(
      tensor_dict,
      batch_size=batch_size_per_clone,
      batch_queue_capacity=batch_queue_capacity,
      num_batch_queue_threads=num_batch_queue_threads,
      prefetch_queue_capacity=prefetch_queue_capacity)
  return input_queue
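
For reference, once preprocessor_builder.build has run over the config (see the appendix), data_augmentation_options is just a list of (function, kwargs) tuples. What the random_horizontal_flip config above builds into:

from object_detection.core import preprocessor

# One tuple per configured step; the empty keypoint_flip_permutation is what
# build() produces when no keypoints are configured.
data_augmentation_options = [
    (preprocessor.random_horizontal_flip, {'keypoint_flip_permutation': ()}),
]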


def get_inputs(input_queue,
               num_classes,
               merge_multiple_label_boxes=False,
               use_multiclass_scores=False):
  """Dequeues batch and constructs inputs to object detection model.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.
    use_multiclass_scores: Whether to use multiclass scores instead of
      groundtruth_classes.

  Returns:
    images: a list of 3-D float tensors of images.
    image_keys: a list of string keys for the images.
    locations_list: a list of tensors of shape [num_boxes, 4]
      containing the corners of the groundtruth boxes.
    classes_list: a list of padded one-hot tensors containing target classes.
    masks_list: a list of 3-D float tensors of shape [num_boxes, image_height,
      image_width] containing instance masks for objects if present in the
      input_queue. Else returns None.
    keypoints_list: a list of 3-D float tensors of shape [num_boxes,
      num_keypoints, 2] containing keypoints for objects if present in the
      input queue. Else returns None.
    weights_lists: a list of 1-D float32 tensors of shape [num_boxes]
      containing groundtruth weight for each box.
  """

  # Dequeue one batch from the prefetch queue: a list of batch_size
  # tensor_dicts, each holding an image of shape [1, height, width, 3].
  read_data_list = input_queue.dequeue()
  label_id_offset = 1

  # Parse each dequeued tensor_dict.
  def extract_images_and_targets(read_data):
    """Extract images and targets from the input dict."""
    image = read_data[fields.InputDataFields.image]
    key = ''
    if fields.InputDataFields.source_id in read_data:
      key = read_data[fields.InputDataFields.source_id]
    location_gt = read_data[fields.InputDataFields.groundtruth_boxes]
    classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes],
                         tf.int32)
    classes_gt -= label_id_offset

    if merge_multiple_label_boxes and use_multiclass_scores:
      raise ValueError(
          'Using both merge_multiple_label_boxes and use_multiclass_scores is'
          'not supported'
      )

    if merge_multiple_label_boxes:
      location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
          location_gt, classes_gt, num_classes)
    elif use_multiclass_scores:
      classes_gt = tf.cast(read_data[fields.InputDataFields.multiclass_scores],
                           tf.float32)
    else:
      classes_gt = util_ops.padded_one_hot_encoding(
          indices=classes_gt, depth=num_classes, left_pad=0)
    masks_gt = read_data.get(fields.InputDataFields.groundtruth_instance_masks)
    keypoints_gt = read_data.get(fields.InputDataFields.groundtruth_keypoints)
    if (merge_multiple_label_boxes and (
        masks_gt is not None or keypoints_gt is not None)):
      raise NotImplementedError('Multi-label support is only for boxes.')
    weights_gt = read_data.get(
        fields.InputDataFields.groundtruth_weights)
    return (image, key, location_gt, classes_gt, masks_gt, keypoints_gt,
            weights_gt)

  return zip(*map(extract_images_and_targets, read_data_list))
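
To make the label handling concrete: groundtruth classes arrive 1-indexed, label_id_offset shifts them to 0-indexed, and padded_one_hot_encoding converts them to one-hot rows. A minimal illustration, using tf.one_hot as a stand-in for util_ops.padded_one_hot_encoding with left_pad=0:

import tensorflow as tf

classes_gt = tf.constant([1, 3])   # 1-indexed labels as read from the input
classes_gt -= 1                    # apply label_id_offset: now [0, 2]
one_hot = tf.one_hot(classes_gt, depth=3)
# one_hot == [[1., 0., 0.],
#             [0., 0., 1.]]        # one row per box, num_classes columns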


def _create_losses(input_queue, create_model_fn, train_config):
  """Creates loss function for a DetectionModel.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
  """
  detection_model = create_model_fn()
  # Read one batch from the prefetch queue.
  (images, _, groundtruth_boxes_list, groundtruth_classes_list,
   groundtruth_masks_list, groundtruth_keypoints_list, _) = get_inputs(
       input_queue,
       detection_model.num_classes,
       train_config.merge_multiple_label_boxes,
       train_config.use_multiclass_scores)

  preprocessed_images = []
  true_image_shapes = []
  for image in images:
    # Run the model's own preprocessing on each image (typically resizing).
    resized_image, true_image_shape = detection_model.preprocess(image)
    preprocessed_images.append(resized_image)
    true_image_shapes.append(true_image_shape)

  images = tf.concat(preprocessed_images, 0)
  true_image_shapes = tf.concat(true_image_shapes, 0)

  if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None
  if any(keypoints is None for keypoints in groundtruth_keypoints_list):
    groundtruth_keypoints_list = None

  detection_model.provide_groundtruth(groundtruth_boxes_list,
                                      groundtruth_classes_list,
                                      groundtruth_masks_list,
                                      groundtruth_keypoints_list)
  # Feed the images to the predictor.
  prediction_dict = detection_model.predict(images, true_image_shapes)
  # Compute the losses.
  losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
  for loss_tensor in losses_dict.values():
    tf.losses.add_loss(loss_tensor)
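
The contract of detection_model.preprocess in the loop above: take one [1, height, width, 3] float image and return the resized image together with the true shape of its content. A toy stand-in that makes the shapes concrete (illustrative only, not any real model's implementation):

import tensorflow as tf

def toy_preprocess(image):
  """Toy stand-in for DetectionModel.preprocess (illustrative only)."""
  # image: [1, height, width, 3] float32.
  resized = tf.image.resize_images(image, [300, 300])
  # With a plain fixed resize the whole canvas is real content, so the
  # true image shape is simply [300, 300, 3] for the single image.
  true_image_shape = tf.constant([[300, 300, 3]])
  return resized, true_image_shape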


def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):

  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

    # .........
    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          train_config.batch_size // num_clones, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)
    # ...........

The overall preprocessing flow

  1. Before batched reading, augmentation steps are specified via data_augmentation_options.
  2. Batched reading creates two queues:
    Queue 1: N threads each read one example at a time from the dataset and enqueue it into queue 1; one consumer dequeues batch_size examples at a time.
    Queue 2: data dequeued from queue 1 is enqueued into queue 2; calling dequeue() reads a batch of batch_size examples from queue 2.
  3. After batched reading, each image goes through the model's own preprocessing function, detection_model.preprocess, and is then fed to the model.

So the actual preprocessing happens in two places

  1. before batched reading, via the steps specified in data_augmentation_options;
  2. after batched reading, via the model's own preprocessing function detection_model.preprocess.

Batched data reading: two queues are created

  1. Queue 1: N threads each read one example at a time from the dataset and enqueue it into queue 1; one consumer dequeues batch_size examples at a time.
  2. Queue 2: data dequeued from queue 1 is enqueued into queue 2; calling dequeue() reads a batch of batch_size examples from queue 2 (see the sketch below).
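
A minimal sketch of this two-queue structure, using the same primitives as core/batcher.py and core/prefetcher.py (capacities hypothetical):

import tensorflow as tf
from object_detection.core import prefetcher

tensor_dict = create_tensor_dict_fn()          # one decoded example (see trainer.py)
# Queue 1: built by tf.train.batch -- num_threads threads enqueue single
# examples; the returned op dequeues batch_size of them at once.
batched_tensors = tf.train.batch(tensor_dict,
                                 batch_size=24,
                                 num_threads=8,
                                 capacity=150,
                                 dynamic_pad=True)
# Queue 2: a prefetch queue whose elements are whole batches.
prefetch_queue = prefetcher.prefetch(batched_tensors, capacity=5)
batch = prefetch_queue.dequeue()               # one dequeue == one full batch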

Currently supported preprocessing operations

  1. normalize_image
  2. random_pixel_value_scale
  3. random_image_scale
  4. random_rgb_to_gray
  5. random_adjust_brightness
  6. random_adjust_contrast
  7. random_adjust_hue
  8. random_adjust_saturation
  9. random_distort_color
  10. random_jitter_boxes
  11. random_crop_to_aspect_ratio
  12. random_black_patches
  13. rgb_to_gray
  14. scale_boxes_to_pixel_coordinates
  15. subtract_channel_mean
  16. random_horizontal_flip
  17. random_vertical_flip
  18. random_rotation90
  19. random_pad_image
  20. random_crop_pad_image
  21. random_resize_method
  22. resize_image
  23. ssd_random_crop
  24. ssd_random_crop_pad
  25. ssd_random_crop_fixed_aspect_ratio
  26. ssd_random_crop_pad_fixed_aspect_ratio

All preprocessing functions are implemented in core/preprocessor.py.
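
Any of these can also be invoked directly through preprocessor.preprocess; a self-contained sketch with dummy tensors and arbitrarily chosen options:

import tensorflow as tf

from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields

tensor_dict = {
    # Image expanded to [1, height, width, 3], as in create_input_queue.
    fields.InputDataFields.image: tf.zeros([1, 480, 640, 3], tf.float32),
    # Normalized [ymin, xmin, ymax, xmax] box corners.
    fields.InputDataFields.groundtruth_boxes:
        tf.constant([[0.1, 0.1, 0.5, 0.5]]),
    fields.InputDataFields.groundtruth_classes: tf.constant([1]),
}
options = [
    (preprocessor.random_horizontal_flip, {}),
    (preprocessor.random_adjust_brightness, {}),
]
tensor_dict = preprocessor.preprocess(
    tensor_dict, options,
    func_arg_map=preprocessor.get_default_func_arg_map())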

Appendix

builders/preprocessor_builder.py

import tensorflow as tf

from object_detection.core import preprocessor
from object_detection.protos import preprocessor_pb2


def _get_step_config_from_proto(preprocessor_step_config, step_name):
  """Returns the value of a field named step_name from proto.

  Args:
    preprocessor_step_config: A preprocessor_pb2.PreprocessingStep object.
    step_name: Name of the field to get value from.

  Returns:
    result_dict: a sub proto message from preprocessor_step_config which will be
                 later converted to a dictionary.

  Raises:
    ValueError: If field does not exist in proto.
  """
  for field, value in preprocessor_step_config.ListFields():
    if field.name == step_name:
      return value

  raise ValueError('Could not get field %s from proto!' % step_name)


def _get_dict_from_proto(config):
  """Helper function to put all proto fields into a dictionary.

  For many preprocessing steps, there's a trivial 1-1 mapping from proto fields
  to function arguments. This function automatically populates a dictionary with
  the arguments from the proto.

  Protos that CANNOT be trivially populated include:
  * nested messages.
  * steps that check if an optional field is set (i.e. where None != 0).
  * protos that don't map 1-1 to arguments (i.e. list should be reshaped).
  * fields requiring additional validation (i.e. repeated field has n elements).

  Args:
    config: A protobuf object that does not violate the conditions above.

  Returns:
    result_dict: |config| converted into a python dictionary.
  """
  result_dict = {}
  for field, value in config.ListFields():
    result_dict[field.name] = value
  return result_dict
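
# For example (a hedged sketch; the minval/maxval field names are assumed
# from preprocessor.proto):
#   step = preprocessor_pb2.PreprocessingStep()
#   step.random_pixel_value_scale.minval = 0.9
#   step.random_pixel_value_scale.maxval = 1.1
#   build(step)  # -> (preprocessor.random_pixel_value_scale,
#                #     {'minval': 0.9, 'maxval': 1.1}) up to float32 rounding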


# A map from a PreprocessingStep proto config field name to the preprocessing
# function that should be used. The PreprocessingStep proto should be parsable
# with _get_dict_from_proto.
PREPROCESSING_FUNCTION_MAP = {
    'normalize_image': preprocessor.normalize_image,
    'random_pixel_value_scale': preprocessor.random_pixel_value_scale,
    'random_image_scale': preprocessor.random_image_scale,
    'random_rgb_to_gray': preprocessor.random_rgb_to_gray,
    'random_adjust_brightness': preprocessor.random_adjust_brightness,
    'random_adjust_contrast': preprocessor.random_adjust_contrast,
    'random_adjust_hue': preprocessor.random_adjust_hue,
    'random_adjust_saturation': preprocessor.random_adjust_saturation,
    'random_distort_color': preprocessor.random_distort_color,
    'random_jitter_boxes': preprocessor.random_jitter_boxes,
    'random_crop_to_aspect_ratio': preprocessor.random_crop_to_aspect_ratio,
    'random_black_patches': preprocessor.random_black_patches,
    'rgb_to_gray': preprocessor.rgb_to_gray,
    'scale_boxes_to_pixel_coordinates': (
        preprocessor.scale_boxes_to_pixel_coordinates),
    'subtract_channel_mean': preprocessor.subtract_channel_mean,
}


# A map to convert from preprocessor_pb2.ResizeImage.Method enum to
# tf.image.ResizeMethod.
RESIZE_METHOD_MAP = {
    preprocessor_pb2.ResizeImage.AREA: tf.image.ResizeMethod.AREA,
    preprocessor_pb2.ResizeImage.BICUBIC: tf.image.ResizeMethod.BICUBIC,
    preprocessor_pb2.ResizeImage.BILINEAR: tf.image.ResizeMethod.BILINEAR,
    preprocessor_pb2.ResizeImage.NEAREST_NEIGHBOR: (
        tf.image.ResizeMethod.NEAREST_NEIGHBOR),
}

def build(preprocessor_step_config):
  """Builds preprocessing step based on the configuration.

  Args:
    preprocessor_step_config: PreprocessingStep configuration proto.

  Returns:
    function, argmap: A callable function and an argument map to call function
                      with.

  Raises:
    ValueError: On invalid configuration.
  """
  step_type = preprocessor_step_config.WhichOneof('preprocessing_step')

  if step_type in PREPROCESSING_FUNCTION_MAP:
    preprocessing_function = PREPROCESSING_FUNCTION_MAP[step_type]
    step_config = _get_step_config_from_proto(preprocessor_step_config,
                                              step_type)
    function_args = _get_dict_from_proto(step_config)
    return (preprocessing_function, function_args)

  if step_type == 'random_horizontal_flip':
    config = preprocessor_step_config.random_horizontal_flip
    return (preprocessor.random_horizontal_flip,
            {
                'keypoint_flip_permutation': tuple(
                    config.keypoint_flip_permutation),
            })

  if step_type == 'random_vertical_flip':
    config = preprocessor_step_config.random_vertical_flip
    return (preprocessor.random_vertical_flip,
            {
                'keypoint_flip_permutation': tuple(
                    config.keypoint_flip_permutation),
            })

  if step_type == 'random_rotation90':
    return (preprocessor.random_rotation90, {})

  if step_type == 'random_crop_image':
    config = preprocessor_step_config.random_crop_image
    return (preprocessor.random_crop_image,
            {
                'min_object_covered': config.min_object_covered,
                'aspect_ratio_range': (config.min_aspect_ratio,
                                       config.max_aspect_ratio),
                'area_range': (config.min_area, config.max_area),
                'overlap_thresh': config.overlap_thresh,
                'random_coef': config.random_coef,
            })

  if step_type == 'random_pad_image':
    config = preprocessor_step_config.random_pad_image
    min_image_size = None
    if (config.HasField('min_image_height') !=
        config.HasField('min_image_width')):
      raise ValueError('min_image_height and min_image_width should be either '
                       'both set or both unset.')
    if config.HasField('min_image_height'):
      min_image_size = (config.min_image_height, config.min_image_width)

    max_image_size = None
    if (config.HasField('max_image_height') !=
        config.HasField('max_image_width')):
      raise ValueError('max_image_height and max_image_width should be either '
                       'both set or both unset.')
    if config.HasField('max_image_height'):
      max_image_size = (config.max_image_height, config.max_image_width)

    pad_color = config.pad_color
    if pad_color and len(pad_color) != 3:
      raise ValueError('pad_color should have 3 elements (RGB) if set!')
    if not pad_color:
      pad_color = None
    return (preprocessor.random_pad_image,
            {
                'min_image_size': min_image_size,
                'max_image_size': max_image_size,
                'pad_color': pad_color,
            })

  if step_type == 'random_crop_pad_image':
    config = preprocessor_step_config.random_crop_pad_image
    min_padded_size_ratio = config.min_padded_size_ratio
    if min_padded_size_ratio and len(min_padded_size_ratio) != 2:
      raise ValueError('min_padded_size_ratio should have 2 elements if set!')
    max_padded_size_ratio = config.max_padded_size_ratio
    if max_padded_size_ratio and len(max_padded_size_ratio) != 2:
      raise ValueError('max_padded_size_ratio should have 2 elements if set!')
    pad_color = config.pad_color
    if pad_color and len(pad_color) != 3:
      raise ValueError('pad_color should have 3 elements if set!')
    kwargs = {
        'min_object_covered': config.min_object_covered,
        'aspect_ratio_range': (config.min_aspect_ratio,
                               config.max_aspect_ratio),
        'area_range': (config.min_area, config.max_area),
        'overlap_thresh': config.overlap_thresh,
        'random_coef': config.random_coef,
    }
    if min_padded_size_ratio:
      kwargs['min_padded_size_ratio'] = tuple(min_padded_size_ratio)
    if max_padded_size_ratio:
      kwargs['max_padded_size_ratio'] = tuple(max_padded_size_ratio)
    if pad_color:
      kwargs['pad_color'] = tuple(pad_color)
    return (preprocessor.random_crop_pad_image, kwargs)

  if step_type == 'random_resize_method':
    config = preprocessor_step_config.random_resize_method
    return (preprocessor.random_resize_method,
            {
                'target_size': [config.target_height, config.target_width],
            })

  if step_type == 'resize_image':
    config = preprocessor_step_config.resize_image
    method = RESIZE_METHOD_MAP[config.method]
    return (preprocessor.resize_image,
            {
                'new_height': config.new_height,
                'new_width': config.new_width,
                'method': method
            })

  if step_type == 'ssd_random_crop':
    config = preprocessor_step_config.ssd_random_crop
    if config.operations:
      min_object_covered = [op.min_object_covered for op in config.operations]
      aspect_ratio_range = [(op.min_aspect_ratio, op.max_aspect_ratio)
                            for op in config.operations]
      area_range = [(op.min_area, op.max_area) for op in config.operations]
      overlap_thresh = [op.overlap_thresh for op in config.operations]
      random_coef = [op.random_coef for op in config.operations]
      return (preprocessor.ssd_random_crop,
              {
                  'min_object_covered': min_object_covered,
                  'aspect_ratio_range': aspect_ratio_range,
                  'area_range': area_range,
                  'overlap_thresh': overlap_thresh,
                  'random_coef': random_coef,
              })
    return (preprocessor.ssd_random_crop, {})

  if step_type == 'ssd_random_crop_pad':
    config = preprocessor_step_config.ssd_random_crop_pad
    if config.operations:
      min_object_covered = [op.min_object_covered for op in config.operations]
      aspect_ratio_range = [(op.min_aspect_ratio, op.max_aspect_ratio)
                            for op in config.operations]
      area_range = [(op.min_area, op.max_area) for op in config.operations]
      overlap_thresh = [op.overlap_thresh for op in config.operations]
      random_coef = [op.random_coef for op in config.operations]
      min_padded_size_ratio = [tuple(op.min_padded_size_ratio)
                               for op in config.operations]
      max_padded_size_ratio = [tuple(op.max_padded_size_ratio)
                               for op in config.operations]
      pad_color = [(op.pad_color_r, op.pad_color_g, op.pad_color_b)
                   for op in config.operations]
      return (preprocessor.ssd_random_crop_pad,
              {
                  'min_object_covered': min_object_covered,
                  'aspect_ratio_range': aspect_ratio_range,
                  'area_range': area_range,
                  'overlap_thresh': overlap_thresh,
                  'random_coef': random_coef,
                  'min_padded_size_ratio': min_padded_size_ratio,
                  'max_padded_size_ratio': max_padded_size_ratio,
                  'pad_color': pad_color,
              })
    return (preprocessor.ssd_random_crop_pad, {})

  if step_type == 'ssd_random_crop_fixed_aspect_ratio':
    config = preprocessor_step_config.ssd_random_crop_fixed_aspect_ratio
    if config.operations:
      min_object_covered = [op.min_object_covered for op in config.operations]
      area_range = [(op.min_area, op.max_area) for op in config.operations]
      overlap_thresh = [op.overlap_thresh for op in config.operations]
      random_coef = [op.random_coef for op in config.operations]
      return (preprocessor.ssd_random_crop_fixed_aspect_ratio,
              {
                  'min_object_covered': min_object_covered,
                  'aspect_ratio': config.aspect_ratio,
                  'area_range': area_range,
                  'overlap_thresh': overlap_thresh,
                  'random_coef': random_coef,
              })
    return (preprocessor.ssd_random_crop_fixed_aspect_ratio, {})

  if step_type == 'ssd_random_crop_pad_fixed_aspect_ratio':
    config = preprocessor_step_config.ssd_random_crop_pad_fixed_aspect_ratio
    kwargs = {}
    aspect_ratio = config.aspect_ratio
    if aspect_ratio:
      kwargs['aspect_ratio'] = aspect_ratio
    min_padded_size_ratio = config.min_padded_size_ratio
    if min_padded_size_ratio:
      if len(min_padded_size_ratio) != 2:
        raise ValueError('min_padded_size_ratio should have 2 elements if set!')
      kwargs['min_padded_size_ratio'] = tuple(min_padded_size_ratio)
    max_padded_size_ratio = config.max_padded_size_ratio
    if max_padded_size_ratio:
      if len(max_padded_size_ratio) != 2:
        raise ValueError('max_padded_size_ratio should have 2 elements if set!')
      kwargs['max_padded_size_ratio'] = tuple(max_padded_size_ratio)
    if config.operations:
      kwargs['min_object_covered'] = [op.min_object_covered
                                      for op in config.operations]
      kwargs['aspect_ratio_range'] = [(op.min_aspect_ratio, op.max_aspect_ratio)
                                      for op in config.operations]
      kwargs['area_range'] = [(op.min_area, op.max_area)
                              for op in config.operations]
      kwargs['overlap_thresh'] = [op.overlap_thresh for op in config.operations]
      kwargs['random_coef'] = [op.random_coef for op in config.operations]
    return (preprocessor.ssd_random_crop_pad_fixed_aspect_ratio, kwargs)

  raise ValueError('Unknown preprocessing step.')
"""Provides functions to batch a dictionary of input tensors."""
import collections

import tensorflow as tf

from object_detection.core import prefetcher

rt_shape_str = '_runtime_shapes'


# Two queues are created:
# Queue 1: num_batch_queue_threads threads each read one example at a time
# from the dataset and enqueue it; one consumer dequeues batch_size examples
# at a time.
# Queue 2: batches dequeued from queue 1 are enqueued into queue 2, whose
# capacity is prefetch_queue_capacity. Calling dequeue() reads one batch from
# queue 2; each element of queue 2 is a full batch, e.g. images of shape
# [batch_size, height, width, 3].

class BatchQueue(object):
  """BatchQueue class.

  This class creates a batch queue to asynchronously enqueue tensors_dict.
  It also adds a FIFO prefetcher so that the batches are readily available
  for the consumers.  Dequeue ops for a BatchQueue object can be created via
  the Dequeue method which evaluates to a batch of tensor_dict.

  Example input pipeline with batching:
  ------------------------------------
  key, string_tensor = slim.parallel_reader.parallel_read(...)
  tensor_dict = decoder.decode(string_tensor)
  tensor_dict = preprocessor.preprocess(tensor_dict, ...)
  batch_queue = batcher.BatchQueue(tensor_dict,
                                   batch_size=32,
                                   batch_queue_capacity=2000,
                                   num_batch_queue_threads=8,
                                   prefetch_queue_capacity=20)
  tensor_dict = batch_queue.dequeue()
  outputs = Model(tensor_dict)
  ...
  -----------------------------------

  Notes:
  -----
  This class batches tensors of unequal sizes by zero padding and unpadding
  them after generating a batch. This can be computationally expensive when
  batching tensors (such as images) that are of vastly different sizes. So it is
  recommended that the shapes of such tensors be fully defined in tensor_dict
  while other lightweight tensors such as bounding box corners and class labels
  can be of varying sizes. Use either crop or resize operations to fully define
  the shape of an image in tensor_dict.

  It is also recommended to perform any preprocessing operations on tensors
  before passing to BatchQueue and subsequently calling the Dequeue method.

  Another caveat is that this class does not read the last batch if it is not
  full. The current implementation makes it hard to support that use case. So,
  for evaluation, when it is critical to run all the examples through your
  network use the input pipeline example mentioned in core/prefetcher.py.
  """

  def __init__(self, tensor_dict, batch_size, batch_queue_capacity,
               num_batch_queue_threads, prefetch_queue_capacity):
    """Constructs a batch queue holding tensor_dict.

    Args:
      tensor_dict: dictionary of tensors to batch.
      batch_size: batch size.
      batch_queue_capacity: max capacity of the queue from which the tensors are
        batched.
      num_batch_queue_threads: number of threads to use for batching.
      prefetch_queue_capacity: max capacity of the queue used to prefetch
        assembled batches.
    """

    # Remember static shapes to set shapes of batched tensors.
    static_shapes = collections.OrderedDict(
        {key: tensor.get_shape() for key, tensor in tensor_dict.items()})
    # Remember runtime shapes to unpad tensors after batching.
    runtime_shapes = collections.OrderedDict(
        {(key + rt_shape_str): tf.shape(tensor)
         for key, tensor in tensor_dict.items()})

    all_tensors = tensor_dict
    all_tensors.update(runtime_shapes)

    # Create a PaddingFIFOQueue with capacity batch_queue_capacity and start
    # num_batch_queue_threads threads, each enqueueing one element of
    # all_tensors at a time until the queue is full. The returned op dequeues
    # batch_size elements at a time.
    batched_tensors = tf.train.batch(
        all_tensors,
        capacity=batch_queue_capacity,
        batch_size=batch_size,
        dynamic_pad=True,
        num_threads=num_batch_queue_threads)

    # Create a PaddingFIFOQueue with capacity prefetch_queue_capacity and one
    # thread that enqueues batched_tensors into it.
    self._queue = prefetcher.prefetch(batched_tensors,
                                      prefetch_queue_capacity)
    self._static_shapes = static_shapes
    self._batch_size = batch_size

  def dequeue(self):
    """Dequeues a batch of tensor_dict from the BatchQueue.

    TODO: use allow_smaller_final_batch to allow running over the whole eval set

    Returns:
      A list of tensor_dicts of the requested batch_size.
    """
    # Dequeue one element (a full batch of batch_size examples) from the
    # prefetch queue.
    batched_tensors = self._queue.dequeue()
    # Separate input tensors from tensors containing their runtime shapes.
    tensors = {}
    shapes = {}
    for key, batched_tensor in batched_tensors.items():
      unbatched_tensor_list = tf.unstack(batched_tensor)
      for i, unbatched_tensor in enumerate(unbatched_tensor_list):
        if rt_shape_str in key:
          shapes[(key[:-len(rt_shape_str)], i)] = unbatched_tensor
        else:
          tensors[(key, i)] = unbatched_tensor

    # Undo that padding using shapes and create a list of size `batch_size` that
    # contains tensor dictionaries.
    tensor_dict_list = []
    batch_size = self._batch_size
    for batch_id in range(batch_size):
      tensor_dict = {}
      for key in self._static_shapes:
        tensor_dict[key] = tf.slice(tensors[(key, batch_id)],
                                    tf.zeros_like(shapes[(key, batch_id)]),
                                    shapes[(key, batch_id)])
        tensor_dict[key].set_shape(self._static_shapes[key])
      tensor_dict_list.append(tensor_dict)

    return tensor_dict_list

Question: what are static_shapes and runtime_shapes for? tf.train.batch with dynamic_pad=True zero-pads every tensor up to the largest shape in its batch, and passing tensors through queues also discards their static shape information. runtime_shapes records each tensor's true shape before padding so that dequeue() can tf.slice the padding back off, and static_shapes is then used by set_shape() to restore the static shape each tensor had before entering the queue.
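
A minimal sketch of that unpad step (shapes hypothetical): suppose an image entered the batch with shape [300, 400, 3] and was zero-padded to [300, 500, 3].

import tensorflow as tf

padded = tf.zeros([300, 500, 3])            # one unstacked tensor from the batch
runtime_shape = tf.constant([300, 400, 3])  # its recorded pre-padding shape
# The same tf.slice call as in dequeue(): start at the origin and take
# exactly the original extent, discarding the zero padding.
original = tf.slice(padded, tf.zeros_like(runtime_shape), runtime_shape)
# original now has shape [300, 400, 3].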

core/prefetcher.py

import tensorflow as tf


def prefetch(tensor_dict, capacity):
  """Creates a prefetch queue for tensors.

  Creates a FIFO queue to asynchronously enqueue tensor_dicts and returns a
  dequeue op that evaluates to a tensor_dict. This function is useful in
  prefetching preprocessed tensors so that the data is readily available for
  consumers.

  Example input pipeline when you don't need batching:
  ----------------------------------------------------
  key, string_tensor = slim.parallel_reader.parallel_read(...)
  tensor_dict = decoder.decode(string_tensor)
  tensor_dict = preprocessor.preprocess(tensor_dict, ...)
  prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20)
  tensor_dict = prefetch_queue.dequeue()
  outputs = Model(tensor_dict)
  ...
  ----------------------------------------------------

  For input pipelines with batching, refer to core/batcher.py

  Args:
    tensor_dict: a dictionary of tensors to prefetch.
    capacity: the size of the prefetch queue.

  Returns:
    a FIFO prefetcher queue
  """
  names = list(tensor_dict.keys())
  dtypes = [t.dtype for t in tensor_dict.values()]
  shapes = [t.get_shape() for t in tensor_dict.values()]
  prefetch_queue = tf.PaddingFIFOQueue(capacity, dtypes=dtypes,
                                       shapes=shapes,
                                       names=names,
                                       name='prefetch_queue')
  enqueue_op = prefetch_queue.enqueue(tensor_dict)
  tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
      prefetch_queue, [enqueue_op]))
  tf.summary.scalar('queue/%s/fraction_of_%d_full' % (prefetch_queue.name,
                                                      capacity),
                    tf.to_float(prefetch_queue.size()) * (1. / capacity))
  return prefetch_queue