This article analyzes the data input pipeline used by models/research/object_detection/train.py.
Core code
The sample configuration below is taken from samples/configs/faster_rcnn_inception_resnet_v2_atrous_coco.config:
train_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
Reading the data (train.py)
input_config = configs['train_input_config']
def get_next(config):
return dataset_util.make_initializable_iterator(
dataset_builder.build(config)).get_next()
create_input_dict_fn = functools.partial(get_next, input_config)
The overall flow
- Initialize the fields to decode, along with the handler that decodes each field
- Call tf.data.TFRecordDataset to read records from config.input_path, decode them with process_fn, and prefetch input_reader_config.prefetch_size records
- Apply tf.contrib.data.padded_batch_and_drop_remainder to the dataset; any leftover records that do not fill a batch_size are dropped
- Return an initializable iterator (see the consumption sketch below)
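As a rough illustration of how that iterator is consumed (a minimal sketch, not the actual trainer.py call site; it assumes the train.py snippet above has already run), note that make_initializable_iterator registers the iterator initializer under tf.GraphKeys.TABLE_INITIALIZERS, so running tf.tables_initializer() initializes both the label-map hash table and the dataset iterator:
import tensorflow as tf
# Build the input pipeline; input_dict holds the tensors returned by get_next().
input_dict = create_input_dict_fn()
with tf.Session() as sess:
    # tables_initializer() groups everything in TABLE_INITIALIZERS, which here
    # covers the label-map lookup table and the iterator initializer.
    sess.run(tf.tables_initializer())
    example = sess.run(input_dict)
    print(example['image'].shape, example['groundtruth_boxes'].shape)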
Tip: I found that putting all the dependencies into a single file makes reading and analysis much easier.
Appendix
dataset_util.py
def make_initializable_iterator(dataset):
"""Creates an iterator, and initializes tables.
This is useful in cases where make_one_shot_iterator wouldn't work because
the graph contains a hash table that needs to be initialized.
Args:
dataset: A `tf.data.Dataset` object.
Returns:
A `tf.data.Iterator`.
"""
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
return iterator
# Reads data from input_files with file_read_func,
# decodes each record with decode_func,
# and prefetches config.prefetch_size records.
def read_dataset(file_read_func, decode_func, input_files, config):
"""Reads a dataset, and handles repetition and shuffling.
Args:
file_read_func: Function to use in tf.data.Dataset.interleave, to read
every individual file into a tf.data.Dataset.
decode_func: Function to apply to all records.
input_files: A list of file paths to read.
config: A input_reader_builder.InputReader object.
Returns:
A tf.data.Dataset based on config.
"""
# Shard, shuffle, and read files.
filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
0)
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
if config.shuffle:
filename_dataset = filename_dataset.shuffle(
config.filenames_shuffle_buffer_size)
elif config.num_readers > 1:
tf.logging.warning('`shuffle` is false, but the input data stream is '
'still slightly shuffled since `num_readers` > 1.')
filename_dataset = filename_dataset.repeat(config.num_epochs or None)
records_dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
file_read_func, cycle_length=config.num_readers,
block_length=config.read_block_length, sloppy=config.shuffle))
if config.shuffle:
records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
tensor_dataset = records_dataset.map(
decode_func, num_parallel_calls=config.num_parallel_map_calls)
return tensor_dataset.prefetch(config.prefetch_size)
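Before moving on to dataset_builder.py, here is a minimal, hypothetical call to read_dataset that mirrors what build() below does; the record path, the 8 MB read buffer, and the identity decode function are illustrative only:
import functools
import tensorflow as tf
from object_detection.protos import input_reader_pb2
from object_detection.utils import dataset_util
# Default InputReader: shuffle=true, 32 readers, prefetch_size=512, ...
config = input_reader_pb2.InputReader()
# Hypothetical TFRecord path.
files = ['/tmp/mscoco_train.record']
dataset = dataset_util.read_dataset(
    functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
    lambda record: record,  # identity "decode", just for illustration
    files, config)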
dataset_builder.py
# Returns the padding shape for each key in dataset.output_shapes.
def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
spatial_image_shape=None):
"""Returns shapes to pad dataset tensors to before batching.
Args:
dataset: tf.data.Dataset object.
max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
padding.
num_classes: Number of classes in the dataset needed to compute shapes for
padding.
spatial_image_shape: A list of two integers of the form [height, width]
containing expected spatial shape of the image.
Returns:
A dictionary keyed by fields.InputDataFields containing padding shapes for
tensors in the dataset.
Raises:
ValueError: If groundtruth classes is neither rank 1 nor rank 2.
"""
if not spatial_image_shape or spatial_image_shape == [-1, -1]:
height, width = None, None
else:
height, width = spatial_image_shape # pylint: disable=unpacking-non-sequence
padding_shapes = {
fields.InputDataFields.image: [height, width, 3],
fields.InputDataFields.source_id: [],
fields.InputDataFields.filename: [],
fields.InputDataFields.key: [],
fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
fields.InputDataFields.groundtruth_instance_masks: [max_num_boxes, height,
width],
fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
fields.InputDataFields.groundtruth_area: [max_num_boxes],
fields.InputDataFields.groundtruth_weights: [max_num_boxes],
fields.InputDataFields.num_groundtruth_boxes: [],
fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
fields.InputDataFields.true_image_shape: [3],
fields.InputDataFields.multiclass_scores: [
max_num_boxes, num_classes + 1 if num_classes is not None else None],
}
# Determine whether groundtruth_classes are integers or one-hot encodings, and
# apply batching appropriately.
classes_shape = dataset.output_shapes[
fields.InputDataFields.groundtruth_classes]
if len(classes_shape) == 1: # Class integers.
padding_shapes[fields.InputDataFields.groundtruth_classes] = [max_num_boxes]
elif len(classes_shape) == 2: # One-hot or k-hot encoding.
padding_shapes[fields.InputDataFields.groundtruth_classes] = [
max_num_boxes, num_classes]
else:
raise ValueError('Groundtruth classes must be a rank 1 tensor (classes) or '
'rank 2 tensor (one-hot encodings)')
if fields.InputDataFields.original_image in dataset.output_shapes:
padding_shapes[fields.InputDataFields.original_image] = [None, None, 3]
if fields.InputDataFields.groundtruth_keypoints in dataset.output_shapes:
tensor_shape = dataset.output_shapes[fields.InputDataFields.
groundtruth_keypoints]
padding_shape = [max_num_boxes, tensor_shape[1].value,
tensor_shape[2].value]
padding_shapes[fields.InputDataFields.groundtruth_keypoints] = padding_shape
if (fields.InputDataFields.groundtruth_keypoint_visibilities
in dataset.output_shapes):
tensor_shape = dataset.output_shapes[fields.InputDataFields.
groundtruth_keypoint_visibilities]
padding_shape = [max_num_boxes, tensor_shape[1].value]
padding_shapes[fields.InputDataFields.
groundtruth_keypoint_visibilities] = padding_shape
return {tensor_key: padding_shapes[tensor_key]
for tensor_key, _ in dataset.output_shapes.items()}
def build(input_reader_config, transform_input_data_fn=None,
batch_size=None, max_num_boxes=None, num_classes=None,
spatial_image_shape=None):
"""Builds a tf.data.Dataset.
Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
records. Applies a padded batch to the resulting dataset.
Args:
input_reader_config: A input_reader_pb2.InputReader object.
transform_input_data_fn: Function to apply to all records, or None if
no extra decoding is required.
batch_size: Batch size. If None, batching is not performed.
max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
padding. If None, will use a dynamic shape.
num_classes: Number of classes in the dataset needed to compute shapes for
padding. If None, will use a dynamic shape.
spatial_image_shape: A list of two integers of the form [height, width]
containing expected spatial shape of the image after applying
transform_input_data_fn. If None, will use dynamic shapes.
Returns:
A tf.data.Dataset based on the input_reader_config.
Raises:
ValueError: On invalid input reader proto.
ValueError: If no input paths are specified.
"""
# Must be an input_reader_pb2.InputReader object.
if not isinstance(input_reader_config, input_reader_pb2.InputReader):
raise ValueError('input_reader_config not of type '
'input_reader_pb2.InputReader.')
# The input_reader oneof of input_reader_config is set to tf_record_input_reader.
if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
# config is the nested TFRecordInputReader message.
config = input_reader_config.tf_record_input_reader
if not config.input_path:
raise ValueError('At least one input path must be specified in '
'`input_reader_config`.')
label_map_proto_file = None
if input_reader_config.HasField('label_map_path'):
label_map_proto_file = input_reader_config.label_map_path
# Initialize the fields to decode and the handler for each field.
decoder = tf_example_decoder.TfExampleDecoder(
load_instance_masks=input_reader_config.load_instance_masks,
instance_mask_type=input_reader_config.mask_type,
label_map_proto_file=label_map_proto_file)
def process_fn(value):
processed = decoder.decode(value)
if transform_input_data_fn is not None:
return transform_input_data_fn(processed)
return processed
# Read records from config.input_path with tf.data.TFRecordDataset, decode them
# with process_fn, and prefetch input_reader_config.prefetch_size records.
dataset = dataset_util.read_dataset(
functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
process_fn, config.input_path[:], input_reader_config)
if batch_size:
padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
spatial_image_shape)
# Apply tf.contrib.data.padded_batch_and_drop_remainder to the dataset.
dataset = dataset.apply(
tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
padding_shapes))
return dataset
raise ValueError('Unsupported input_reader_config.')
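To tie build() together, a hypothetical end-to-end call might look like the sketch below; the text proto and all paths and values are placeholders, and setting batch_size triggers the padded batching described above:
from google.protobuf import text_format
from object_detection.builders import dataset_builder
from object_detection.protos import input_reader_pb2
# Hypothetical input reader config; the record and label-map paths are placeholders.
input_reader_text = """
  shuffle: true
  num_readers: 4
  tf_record_input_reader {
    input_path: "/tmp/mscoco_train.record"
  }
  label_map_path: "/tmp/mscoco_label_map.pbtxt"
"""
input_reader_config = text_format.Parse(input_reader_text,
                                        input_reader_pb2.InputReader())
# With batch_size set, every tensor is padded to the shapes computed by
# _get_padding_shapes and the last incomplete batch is dropped.
dataset = dataset_builder.build(input_reader_config,
                                batch_size=8,
                                max_num_boxes=100,
                                num_classes=90,
                                spatial_image_shape=[600, 600])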
input_reader.proto
syntax = "proto2";
package object_detection.protos;
// Configuration proto for defining input readers that generate Object Detection
// Examples from input sources. Input readers are expected to generate a
// dictionary of tensors, with the following fields populated:
//
// 'image': an [image_height, image_width, channels] image tensor that detection
// will be run on.
// 'groundtruth_classes': a [num_boxes] int32 tensor storing the class
// labels of detected boxes in the image.
// 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of
// detected boxes in the image.
// 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height,
// image_width] float tensor storing binary mask of the objects in boxes.
// Instance mask format. Note that PNG masks are much more space efficient.
enum InstanceMaskType {
DEFAULT = 0; // Default implementation, currently NUMERICAL_MASKS
NUMERICAL_MASKS = 1; // [num_masks, H, W] float32 binary masks.
PNG_MASKS = 2; // Encoded PNG masks.
}
message InputReader {
// Path to StringIntLabelMap pbtxt file specifying the mapping from string
// labels to integer ids.
optional string label_map_path = 1 [default=""];
// Whether data should be processed in the order they are read in, or
// shuffled randomly.
optional bool shuffle = 2 [default=true];
// Buffer size to be used when shuffling.
optional uint32 shuffle_buffer_size = 11 [default = 2048];
// Buffer size to be used when shuffling file names.
optional uint32 filenames_shuffle_buffer_size = 12 [default = 100];
// Maximum number of records to keep in reader queue.
optional uint32 queue_capacity = 3 [default=2000];
// Minimum number of records to keep in reader queue. A large value is needed
// to generate a good random shuffle.
optional uint32 min_after_dequeue = 4 [default=1000];
// The number of times a data source is read. If set to zero, the data source
// will be reused indefinitely.
optional uint32 num_epochs = 5 [default=0];
// Number of reader instances to create.
optional uint32 num_readers = 6 [default=32];
// Number of records to read from each reader at once.
optional uint32 read_block_length = 15 [default=32];
// Number of decoded records to prefetch before batching.
optional uint32 prefetch_size = 13 [default = 512];
// Number of parallel decode ops to apply.
optional uint32 num_parallel_map_calls = 14 [default = 64];
// Number of groundtruth keypoints per object.
optional uint32 num_keypoints = 16 [default = 0];
// Whether to load groundtruth instance masks.
optional bool load_instance_masks = 7 [default = false];
// Type of instance mask.
optional InstanceMaskType mask_type = 10 [default = NUMERICAL_MASKS];
oneof input_reader {
TFRecordInputReader tf_record_input_reader = 8;
ExternalInputReader external_input_reader = 9;
}
}
// An input reader that reads TF Example protos from local TFRecord files.
message TFRecordInputReader {
// Path(s) to `TFRecordFile`s.
repeated string input_path = 1;
}
// An externally defined input reader. Users may define an extension to this
// proto to interface their own input readers.
message ExternalInputReader {
extensions 1 to 999;
}
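For reference, each of the knobs defined above can be set directly in the pipeline config's train_input_reader block; a hypothetical example (the values are illustrative and the paths are placeholders):
train_input_reader: {
  shuffle: true
  num_readers: 8
  read_block_length: 32
  shuffle_buffer_size: 2048
  prefetch_size: 512
  num_parallel_map_calls: 16
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}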
tf_example_decoder.py
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.protos import input_reader_pb2
from object_detection.utils import label_map_util
slim_example_decoder = tf.contrib.slim.tfexample_decoder
# TODO(lzc): keep LookupTensor and BackupHandler in sync with
# tf.contrib.slim.tfexample_decoder version.
class LookupTensor(slim_example_decoder.Tensor):
"""An ItemHandler that returns a parsed Tensor, the result of a lookup."""
def __init__(self,
tensor_key,
table,
shape_keys=None,
shape=None,
default_value=''):
"""Initializes the LookupTensor handler.
Simply calls a vocabulary (most often, a label mapping) lookup.
Args:
tensor_key: the name of the `TFExample` feature to read the tensor from.
table: A tf.lookup table.
shape_keys: Optional name or list of names of the TF-Example feature in
which the tensor shape is stored. If a list, then each corresponds to
one dimension of the shape.
shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
reshaped accordingly.
default_value: The value used when the `tensor_key` is not found in a
particular `TFExample`.
Raises:
ValueError: if both `shape_keys` and `shape` are specified.
"""
self._table = table
super(LookupTensor, self).__init__(tensor_key, shape_keys, shape,
default_value)
def tensors_to_item(self, keys_to_tensors):
unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
return self._table.lookup(unmapped_tensor)
class BackupHandler(slim_example_decoder.ItemHandler):
"""An ItemHandler that tries two ItemHandlers in order."""
def __init__(self, handler, backup):
"""Initializes the BackupHandler handler.
If the first Handler's tensors_to_item returns a Tensor with no elements,
the second Handler is used.
Args:
handler: The primary ItemHandler.
backup: The backup ItemHandler.
Raises:
ValueError: if either is not an ItemHandler.
"""
if not isinstance(handler, slim_example_decoder.ItemHandler):
raise ValueError('Primary handler is of type %s instead of ItemHandler' %
type(handler))
if not isinstance(backup, slim_example_decoder.ItemHandler):
raise ValueError(
'Backup handler is of type %s instead of ItemHandler' % type(backup))
self._handler = handler
self._backup = backup
super(BackupHandler, self).__init__(handler.keys + backup.keys)
def tensors_to_item(self, keys_to_tensors):
item = self._handler.tensors_to_item(keys_to_tensors)
return control_flow_ops.cond(
pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0),
true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
false_fn=lambda: item)
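To make the role of these two handlers concrete, here is a small sketch (the label names are made up) of the construction that TfExampleDecoder performs further down: class text is mapped through a lookup table, and records without text fall back to the stored numeric label.
import tensorflow as tf
slim_example_decoder = tf.contrib.slim.tfexample_decoder
# Hypothetical text -> id lookup table.
table = tf.contrib.lookup.HashTable(
    initializer=tf.contrib.lookup.KeyValueTensorInitializer(
        keys=tf.constant(['cat', 'dog']),
        values=tf.constant([1, 2], dtype=tf.int64)),
    default_value=-1)
# BackupHandler and LookupTensor are the classes defined above in this file.
# Try the class text first; if the record has no text, fall back to the raw
# 'image/object/class/label' feature.
label_handler = BackupHandler(
    LookupTensor('image/object/class/text', table, default_value=''),
    slim_example_decoder.Tensor('image/object/class/label'))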
# The constructor specifies which fields to decode and the handler for each field.
# decode() calls parsing_ops.parse_single_example to parse those fields, then runs each field's handler.
class TfExampleDecoder(data_decoder.DataDecoder):
"""Tensorflow Example proto decoder."""
# Initialize the handlers for the various fields.
def __init__(self,
load_instance_masks=False,
instance_mask_type=input_reader_pb2.NUMERICAL_MASKS,
label_map_proto_file=None,
use_display_name=False,
dct_method='',
num_keypoints=0):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
load_instance_masks: whether or not to load and handle instance masks.
instance_mask_type: type of instance masks. Options are provided in
input_reader.proto. This is only used if `load_instance_masks` is True.
label_map_proto_file: a file path to a
object_detection.protos.StringIntLabelMap proto. If provided, then the
mapped IDs of 'image/object/class/text' will take precedence over the
existing 'image/object/class/label' ID. Also, if provided, it is
assumed that 'image/object/class/text' will be in the data.
use_display_name: whether or not to use the `display_name` for label
mapping (instead of `name`). Only used if label_map_proto_file is
provided.
dct_method: An optional string. Defaults to None. It only takes
effect when image format is jpeg, used to specify a hint about the
algorithm used for jpeg decompression. Currently valid values
are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
example, the jpeg library does not have that specific option.
num_keypoints: the number of keypoints per object.
Raises:
ValueError: If `instance_mask_type` option is not one of
input_reader_pb2.DEFAULT, input_reader_pb2.NUMERICAL, or
input_reader_pb2.PNG_MASKS.
"""
self.keys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/filename':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/key/sha256':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/source_id':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/height':
tf.FixedLenFeature((), tf.int64, default_value=1),
'image/width':
tf.FixedLenFeature((), tf.int64, default_value=1),
# Object boxes and classes.
'image/object/bbox/xmin':
tf.VarLenFeature(tf.float32),
'image/object/bbox/xmax':
tf.VarLenFeature(tf.float32),
'image/object/bbox/ymin':
tf.VarLenFeature(tf.float32),
'image/object/bbox/ymax':
tf.VarLenFeature(tf.float32),
'image/object/class/label':
tf.VarLenFeature(tf.int64),
'image/object/class/text':
tf.VarLenFeature(tf.string),
'image/object/area':
tf.VarLenFeature(tf.float32),
'image/object/is_crowd':
tf.VarLenFeature(tf.int64),
'image/object/difficult':
tf.VarLenFeature(tf.int64),
'image/object/group_of':
tf.VarLenFeature(tf.int64),
'image/object/weight':
tf.VarLenFeature(tf.float32),
}
# For JPEG images, a dct_method hint can speed up decoding.
if dct_method:
# The image handler decodes the encoded image:
# if image/format is raw or RAW, it calls parsing_ops.decode_raw;
# if image/format is jpeg, it calls image_ops.decode_jpeg;
# otherwise it calls image_ops.decode_image.
image = slim_example_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3,
dct_method=dct_method)
else:
# Same as above, but without the dct_method hint.
image = slim_example_decoder.Image(
image_key='image/encoded', format_key='image/format', channels=3)
self.items_to_handlers = {
fields.InputDataFields.image:
image,
fields.InputDataFields.source_id: (
slim_example_decoder.Tensor('image/source_id')),
fields.InputDataFields.key: (
slim_example_decoder.Tensor('image/key/sha256')),
fields.InputDataFields.filename: (
slim_example_decoder.Tensor('image/filename')),
# Object boxes and classes.
fields.InputDataFields.groundtruth_boxes: (
slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
'image/object/bbox/')),
fields.InputDataFields.groundtruth_area:
slim_example_decoder.Tensor('image/object/area'),
fields.InputDataFields.groundtruth_is_crowd: (
slim_example_decoder.Tensor('image/object/is_crowd')),
fields.InputDataFields.groundtruth_difficult: (
slim_example_decoder.Tensor('image/object/difficult')),
fields.InputDataFields.groundtruth_group_of: (
slim_example_decoder.Tensor('image/object/group_of')),
fields.InputDataFields.groundtruth_weights: (
slim_example_decoder.Tensor('image/object/weight')),
}
self._num_keypoints = num_keypoints
if num_keypoints > 0:
self.keys_to_features['image/object/keypoint/x'] = (
tf.VarLenFeature(tf.float32))
self.keys_to_features['image/object/keypoint/y'] = (
tf.VarLenFeature(tf.float32))
self.items_to_handlers[fields.InputDataFields.groundtruth_keypoints] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/keypoint/y', 'image/object/keypoint/x'],
self._reshape_keypoints))
if load_instance_masks:
if instance_mask_type in (input_reader_pb2.DEFAULT,
input_reader_pb2.NUMERICAL_MASKS):
self.keys_to_features['image/object/mask'] = (
tf.VarLenFeature(tf.float32))
self.items_to_handlers[
fields.InputDataFields.groundtruth_instance_masks] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/mask', 'image/height', 'image/width'],
self._reshape_instance_masks))
elif instance_mask_type == input_reader_pb2.PNG_MASKS:
self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string)
self.items_to_handlers[
fields.InputDataFields.groundtruth_instance_masks] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/mask', 'image/height', 'image/width'],
self._decode_png_instance_masks))
else:
raise ValueError('Did not recognize the `instance_mask_type` option.')
if label_map_proto_file:
label_map = label_map_util.get_label_map_dict(label_map_proto_file,
use_display_name)
# We use a default_value of -1, but we expect all labels to be contained
# in the label map.
table = tf.contrib.lookup.HashTable(
initializer=tf.contrib.lookup.KeyValueTensorInitializer(
keys=tf.constant(list(label_map.keys())),
values=tf.constant(list(label_map.values()), dtype=tf.int64)),
default_value=-1)
# If the label_map_proto is provided, try to use it in conjunction with
# the class text, and fall back to a materialized ID.
# TODO(lzc): note that here we are using BackupHandler defined in this
# file(which is branching slim_example_decoder.BackupHandler). Need to
# switch back to slim_example_decoder.BackupHandler once tf 1.5 becomes
# more popular.
label_handler = BackupHandler(
LookupTensor('image/object/class/text', table, default_value=''),
slim_example_decoder.Tensor('image/object/class/label'))
else:
label_handler = slim_example_decoder.Tensor('image/object/class/label')
self.items_to_handlers[
fields.InputDataFields.groundtruth_classes] = label_handler
def decode(self, tf_example_string_tensor):
"""Decodes serialized tensorflow example and returns a tensor dictionary.
Args:
tf_example_string_tensor: a string tensor holding a serialized tensorflow
example proto.
Returns:
A dictionary of the following tensors.
fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, 3]
containing image.
fields.InputDataFields.source_id - string tensor containing original
image id.
fields.InputDataFields.key - string tensor with unique sha256 hash key.
fields.InputDataFields.filename - string tensor with original dataset
filename.
fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
[None, 4] containing box corners.
fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
[None] containing classes for the boxes.
fields.InputDataFields.groundtruth_weights - 1D float32 tensor of
shape [None] indicating the weights of groundtruth boxes.
fields.InputDataFields.num_groundtruth_boxes - int32 scalar indicating
the number of groundtruth_boxes.
fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
[None] containing object mask area in pixels squared.
fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
[None] indicating if the boxes enclose a crowd.
Optional:
fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
[None] indicating if the boxes represent `difficult` instances.
fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
[None] indicating if the boxes represent `group_of` instances.
fields.InputDataFields.groundtruth_keypoints - 3D float32 tensor of
shape [None, None, 2] containing keypoints, where the coordinates of
the keypoints are ordered (y, x).
fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
shape [None, None, None] containing instance masks.
"""
serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
# Create a TFExampleDecoder, specifying the fields to decode and each field's handler.
decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
self.items_to_handlers)
keys = decoder.list_items()
# Call parsing_ops.parse_single_example to parse the string tensor;
# the fields to decode are given by keys.
tensors = decoder.decode(serialized_example, items=keys)
tensor_dict = dict(zip(keys, tensors))
is_crowd = fields.InputDataFields.groundtruth_is_crowd
tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool)
tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]
def default_groundtruth_weights():
return tf.ones(
[tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]],
dtype=tf.float32)
tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond(
tf.greater(
tf.shape(
tensor_dict[fields.InputDataFields.groundtruth_weights])[0],
0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
default_groundtruth_weights)
return tensor_dict
def _reshape_keypoints(self, keys_to_tensors):
"""Reshape keypoints.
The keypoints are reshaped to [num_instances, num_keypoints, 2].
Args:
keys_to_tensors: a dictionary from keys to tensors.
Returns:
A 3-D float tensor of shape [num_instances, num_keypoints, 2] containing the
keypoint coordinates.
"""
y = keys_to_tensors['image/object/keypoint/y']
if isinstance(y, tf.SparseTensor):
y = tf.sparse_tensor_to_dense(y)
y = tf.expand_dims(y, 1)
x = keys_to_tensors['image/object/keypoint/x']
if isinstance(x, tf.SparseTensor):
x = tf.sparse_tensor_to_dense(x)
x = tf.expand_dims(x, 1)
keypoints = tf.concat([y, x], 1)
keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
return keypoints
def _reshape_instance_masks(self, keys_to_tensors):
"""Reshape instance segmentation masks.
The instance segmentation masks are reshaped to [num_instances, height, width].
Args:
keys_to_tensors: a dictionary from keys to tensors.
Returns:
A 3-D float tensor of shape [num_instances, height, width] with values in {0, 1}.
"""
height = keys_to_tensors['image/height']
width = keys_to_tensors['image/width']
to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
masks = keys_to_tensors['image/object/mask']
if isinstance(masks, tf.SparseTensor):
masks = tf.sparse_tensor_to_dense(masks)
masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
return tf.cast(masks, tf.float32)
def _decode_png_instance_masks(self, keys_to_tensors):
"""Decode PNG instance segmentation masks and stack into dense tensor.
The instance segmentation masks are reshaped to [num_instances, height, width].
Args:
keys_to_tensors: a dictionary from keys to tensors.
Returns:
A 3-D float tensor of shape [num_instances, height, width] with values in {0, 1}.
"""
def decode_png_mask(image_buffer):
image = tf.squeeze(
tf.image.decode_image(image_buffer, channels=1), axis=2)
image.set_shape([None, None])
image = tf.to_float(tf.greater(image, 0))
return image
png_masks = keys_to_tensors['image/object/mask']
height = keys_to_tensors['image/height']
width = keys_to_tensors['image/width']
if isinstance(png_masks, tf.SparseTensor):
png_masks = tf.sparse_tensor_to_dense(png_masks, default_value='')
return tf.cond(
tf.greater(tf.size(png_masks), 0),
lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
lambda: tf.zeros(tf.to_int32(tf.stack([0, height, width]))))
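As a standalone illustration (a minimal sketch; the tf.train.Example below is a toy with a fake raw image payload and no groundtruth features), the decoder can be driven directly on a serialized example string:
import tensorflow as tf
from object_detection.data_decoders import tf_example_decoder
# Toy example: a fake 'raw' image and no object features at all;
# the VarLenFeature fields simply come back empty.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'\x00\x00\x00'])),
    'image/format': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'raw'])),
}))
decoder = tf_example_decoder.TfExampleDecoder()
tensor_dict = decoder.decode(tf.constant(example.SerializeToString()))
# tensor_dict maps field names to graph tensors, e.g. 'image',
# 'groundtruth_boxes', 'groundtruth_classes', 'num_groundtruth_boxes', ...
print(sorted(tensor_dict.keys()))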
tfexample_decoder.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from tensorflow.contrib.slim.python.slim.data import data_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
class ItemHandler(object):
"""Specifies the item-to-Features mapping for tf.parse_example.
An ItemHandler both specifies a list of Features used for parsing an Example
proto as well as a function that post-processes the results of Example
parsing.
"""
__metaclass__ = abc.ABCMeta
def __init__(self, keys):
"""Constructs the handler with the name of the tf.Feature keys to use.
See third_party/tensorflow/core/example/feature.proto
Args:
keys: the name of the TensorFlow Example Feature.
"""
if not isinstance(keys, (tuple, list)):
keys = [keys]
self._keys = keys
@property
def keys(self):
return self._keys
@abc.abstractmethod
def tensors_to_item(self, keys_to_tensors):
"""Maps the given dictionary of tensors to the requested item.
Args:
keys_to_tensors: a mapping of TF-Example keys to parsed tensors.
Returns:
the final tensor representing the item being handled.
"""
pass
class ItemHandlerCallback(ItemHandler):
"""An ItemHandler that converts the parsed tensors via a given function.
Unlike other ItemHandlers, the ItemHandlerCallback resolves its item via
a callback function rather than using prespecified behavior.
"""
def __init__(self, keys, func):
"""Initializes the ItemHandler.
Args:
keys: a list of TF-Example keys.
func: a function that takes as an argument a dictionary from `keys` to
parsed Tensors.
"""
super(ItemHandlerCallback, self).__init__(keys)
self._func = func
def tensors_to_item(self, keys_to_tensors):
return self._func(keys_to_tensors)
class BoundingBox(ItemHandler):
"""An ItemHandler that concatenates a set of parsed Tensors to Bounding Boxes.
"""
def __init__(self, keys=None, prefix=''):
"""Initialize the bounding box handler.
Args:
keys: A list of four key names representing the ymin, xmin, ymax, xmax
prefix: An optional prefix for each of the bounding box keys.
If provided, `prefix` is appended to each key in `keys`.
Raises:
ValueError: if keys is not `None` and also not a list of exactly 4 keys
"""
if keys is None:
keys = ['ymin', 'xmin', 'ymax', 'xmax']
elif len(keys) != 4:
raise ValueError('BoundingBox expects 4 keys but got {}'.format(
len(keys)))
self._prefix = prefix
self._keys = keys
self._full_keys = [prefix + k for k in keys]
super(BoundingBox, self).__init__(self._full_keys)
def tensors_to_item(self, keys_to_tensors):
"""Maps the given dictionary of tensors to a concatenated list of bboxes.
Args:
keys_to_tensors: a mapping of TF-Example keys to parsed tensors.
Returns:
[num_boxes, 4] tensor of bounding box coordinates,
i.e. 1 bounding box per row, in order [y_min, x_min, y_max, x_max].
"""
sides = []
for key in self._full_keys:
side = keys_to_tensors[key]
if isinstance(side, sparse_tensor.SparseTensor):
side = side.values
side = array_ops.expand_dims(side, 0)
sides.append(side)
bounding_box = array_ops.concat(sides, 0)
return array_ops.transpose(bounding_box)
class Tensor(ItemHandler):
"""An ItemHandler that returns a parsed Tensor."""
def __init__(self, tensor_key, shape_keys=None, shape=None, default_value=0):
"""Initializes the Tensor handler.
Tensors are, by default, returned without any reshaping. However, there are
two mechanisms which allow reshaping to occur at load time. If `shape_keys`
is provided, both the `Tensor` corresponding to `tensor_key` and
`shape_keys` is loaded and the former `Tensor` is reshaped with the values
of the latter. Alternatively, if a fixed `shape` is provided, the `Tensor`
corresponding to `tensor_key` is loaded and reshaped accordingly.
If neither `shape_keys` nor `shape` are provided, the `Tensor` will be
returned without any reshaping.
Args:
tensor_key: the name of the `TFExample` feature to read the tensor from.
shape_keys: Optional name or list of names of the TF-Example feature in
which the tensor shape is stored. If a list, then each corresponds to
one dimension of the shape.
shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
reshaped accordingly.
default_value: The value used when the `tensor_key` is not found in a
particular `TFExample`.
Raises:
ValueError: if both `shape_keys` and `shape` are specified.
"""
if shape_keys and shape is not None:
raise ValueError('Cannot specify both shape_keys and shape parameters.')
if shape_keys and not isinstance(shape_keys, list):
shape_keys = [shape_keys]
self._tensor_key = tensor_key
self._shape_keys = shape_keys
self._shape = shape
self._default_value = default_value
keys = [tensor_key]
if shape_keys:
keys.extend(shape_keys)
super(Tensor, self).__init__(keys)
def tensors_to_item(self, keys_to_tensors):
tensor = keys_to_tensors[self._tensor_key]
shape = self._shape
if self._shape_keys:
shape_dims = []
for k in self._shape_keys:
shape_dim = keys_to_tensors[k]
if isinstance(shape_dim, sparse_tensor.SparseTensor):
shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim)
shape_dims.append(shape_dim)
shape = array_ops.reshape(array_ops.stack(shape_dims), [-1])
if isinstance(tensor, sparse_tensor.SparseTensor):
if shape is not None:
tensor = sparse_ops.sparse_reshape(tensor, shape)
tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
else:
if shape is not None:
tensor = array_ops.reshape(tensor, shape)
return tensor
class LookupTensor(Tensor):
"""An ItemHandler that returns a parsed Tensor, the result of a lookup."""
def __init__(self,
tensor_key,
table,
shape_keys=None,
shape=None,
default_value=''):
"""Initializes the LookupTensor handler.
See Tensor. Simply calls a vocabulary (most often, a label mapping) lookup.
Args:
tensor_key: the name of the `TFExample` feature to read the tensor from.
table: A tf.lookup table.
shape_keys: Optional name or list of names of the TF-Example feature in
which the tensor shape is stored. If a list, then each corresponds to
one dimension of the shape.
shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
reshaped accordingly.
default_value: The value used when the `tensor_key` is not found in a
particular `TFExample`.
Raises:
ValueError: if both `shape_keys` and `shape` are specified.
"""
self._table = table
super(LookupTensor, self).__init__(tensor_key, shape_keys, shape,
default_value)
def tensors_to_item(self, keys_to_tensors):
unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
return self._table.lookup(unmapped_tensor)
class BackupHandler(ItemHandler):
"""An ItemHandler that tries two ItemHandlers in order."""
def __init__(self, handler, backup):
"""Initializes the BackupHandler handler.
If the first Handler's tensors_to_item returns a Tensor with no elements,
the second Handler is used.
Args:
handler: The primary ItemHandler.
backup: The backup ItemHandler.
Raises:
ValueError: if either is not an ItemHandler.
"""
if not isinstance(handler, ItemHandler):
raise ValueError('Primary handler is of type %s instead of ItemHandler'
% type(handler))
if not isinstance(backup, ItemHandler):
raise ValueError('Backup handler is of type %s instead of ItemHandler'
% type(backup))
self._handler = handler
self._backup = backup
super(BackupHandler, self).__init__(handler.keys + backup.keys)
def tensors_to_item(self, keys_to_tensors):
item = self._handler.tensors_to_item(keys_to_tensors)
return control_flow_ops.cond(
pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0),
true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
false_fn=lambda: item)
class SparseTensor(ItemHandler):
"""An ItemHandler for SparseTensors."""
def __init__(self,
indices_key=None,
values_key=None,
shape_key=None,
shape=None,
densify=False,
default_value=0):
"""Initializes the Tensor handler.
Args:
indices_key: the name of the TF-Example feature that contains the ids.
Defaults to 'indices'.
values_key: the name of the TF-Example feature that contains the values.
Defaults to 'values'.
shape_key: the name of the TF-Example feature that contains the shape.
If provided it would be used.
shape: the output shape of the SparseTensor. If `shape_key` is not
provided this `shape` would be used.
densify: whether to convert the SparseTensor into a dense Tensor.
default_value: Scalar value to set when making dense for indices not
specified in the `SparseTensor`.
"""
indices_key = indices_key or 'indices'
values_key = values_key or 'values'
self._indices_key = indices_key
self._values_key = values_key
self._shape_key = shape_key
self._shape = shape
self._densify = densify
self._default_value = default_value
keys = [indices_key, values_key]
if shape_key:
keys.append(shape_key)
super(SparseTensor, self).__init__(keys)
def tensors_to_item(self, keys_to_tensors):
indices = keys_to_tensors[self._indices_key]
values = keys_to_tensors[self._values_key]
if self._shape_key:
shape = keys_to_tensors[self._shape_key]
if isinstance(shape, sparse_tensor.SparseTensor):
shape = sparse_ops.sparse_tensor_to_dense(shape)
elif self._shape:
shape = self._shape
else:
shape = indices.dense_shape
indices_shape = array_ops.shape(indices.indices)
rank = indices_shape[1]
ids = math_ops.to_int64(indices.values)
indices_columns_to_preserve = array_ops.slice(
indices.indices, [0, 0], array_ops.stack([-1, rank - 1]))
new_indices = array_ops.concat(
[indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1)
tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape)
if self._densify:
tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
return tensor
class Image(ItemHandler):
"""An ItemHandler that decodes a parsed Tensor as an image."""
def __init__(self,
image_key=None,
format_key=None,
shape=None,
channels=3,
dtype=dtypes.uint8,
repeated=False,
dct_method=''):
"""Initializes the image.
Args:
image_key: the name of the TF-Example feature in which the encoded image
is stored.
format_key: the name of the TF-Example feature in which the image format
is stored.
shape: the output shape of the image as 1-D `Tensor`
[height, width, channels]. If provided, the image is reshaped
accordingly. If left as None, no reshaping is done. A shape should
be supplied only if all the stored images have the same shape.
channels: the number of channels in the image.
dtype: images will be decoded at this bit depth. Different formats
support different bit depths.
See tf.image.decode_image and tf.decode_raw.
repeated: if False, decodes a single image. If True, decodes a
variable number of image strings from a 1D tensor of strings.
dct_method: An optional string. Defaults to empty string. It only takes
effect when image format is jpeg, used to specify a hint about the
algorithm used for jpeg decompression. Currently valid values
are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
example, the jpeg library does not have that specific option.
"""
if not image_key:
image_key = 'image/encoded'
if not format_key:
format_key = 'image/format'
super(Image, self).__init__([image_key, format_key])
self._image_key = image_key
self._format_key = format_key
self._shape = shape
self._channels = channels
self._dtype = dtype
self._repeated = repeated
self._dct_method = dct_method
def tensors_to_item(self, keys_to_tensors):
"""See base class."""
image_buffer = keys_to_tensors[self._image_key]
image_format = keys_to_tensors[self._format_key]
if self._repeated:
return functional_ops.map_fn(lambda x: self._decode(x, image_format),
image_buffer, dtype=self._dtype)
else:
return self._decode(image_buffer, image_format)
def _decode(self, image_buffer, image_format):
"""Decodes the image buffer.
Args:
image_buffer: The tensor representing the encoded image tensor.
image_format: The image format for the image in `image_buffer`. If image
format is `raw`, all images are expected to be in this format, otherwise
this op can decode a mix of `jpg` and `png` formats.
Returns:
A tensor that represents decoded image of self._shape, or
(?, ?, self._channels) if self._shape is not specified.
"""
def decode_image():
"""Decodes a image based on the headers."""
return image_ops.decode_image(image_buffer, channels=self._channels)
def decode_jpeg():
"""Decodes a jpeg image with specified '_dct_method'."""
return image_ops.decode_jpeg(
image_buffer, channels=self._channels, dct_method=self._dct_method)
def check_jpeg():
"""Checks if an image is jpeg."""
# For jpeg, we directly use image_ops.decode_jpeg rather than decode_image
# in order to feed the jpeg-specific parameter 'dct_method'.
return control_flow_ops.cond(
image_ops.is_jpeg(image_buffer),
decode_jpeg,
decode_image,
name='cond_jpeg')
def decode_raw():
"""Decodes a raw image."""
return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)
pred_fn_pairs = {
math_ops.logical_or(
math_ops.equal(image_format, 'raw'),
math_ops.equal(image_format, 'RAW')): decode_raw,
}
image = control_flow_ops.case(
pred_fn_pairs, default=check_jpeg, exclusive=True)
image.set_shape([None, None, self._channels])
if self._shape is not None:
image = array_ops.reshape(image, self._shape)
return image
# The constructor specifies which fields to decode and the handler for each field.
# decode() calls parsing_ops.parse_single_example to parse those fields, then runs each field's handler.
class TFExampleDecoder(data_decoder.DataDecoder):
"""A decoder for TensorFlow Examples.
Decoding Example proto buffers is comprised of two stages: (1) Example parsing
and (2) tensor manipulation.
In the first stage, the tf.parse_example function is called with a list of
FixedLenFeatures and SparseLenFeatures. These instances tell TF how to parse
the example. The output of this stage is a set of tensors.
In the second stage, the resulting tensors are manipulated to provide the
requested 'item' tensors.
To perform this decoding operation, an ExampleDecoder is given a list of
ItemHandlers. Each ItemHandler indicates the set of features for stage 1 and
contains the instructions for post_processing its tensors for stage 2.
"""
def __init__(self, keys_to_features, items_to_handlers):
"""Constructs the decoder.
Args:
keys_to_features: a dictionary from TF-Example keys to either
tf.VarLenFeature or tf.FixedLenFeature instances. See tensorflow's
parsing_ops.py.
items_to_handlers: a dictionary from items (strings) to ItemHandler
instances. Note that the ItemHandler's are provided the keys that they
use to return the final item Tensors.
"""
self._keys_to_features = keys_to_features
self._items_to_handlers = items_to_handlers
def list_items(self):
"""See base class."""
return list(self._items_to_handlers.keys())
def decode(self, serialized_example, items=None):
"""Decodes the given serialized TF-example.
Args:
serialized_example: a serialized TF-example tensor.
items: the list of items to decode. These must be a subset of the item
keys in self._items_to_handlers. If `items` is left as None, then all
of the items in self._items_to_handlers are decoded.
Returns:
the decoded items, a list of tensor.
"""
example = parsing_ops.parse_single_example(serialized_example,
self._keys_to_features)
# Reshape non-sparse elements just once, adding the reshape ops in
# deterministic order.
for k in sorted(self._keys_to_features):
v = self._keys_to_features[k]
if isinstance(v, parsing_ops.FixedLenFeature):
example[k] = array_ops.reshape(example[k], v.shape)
if not items:
items = self._items_to_handlers.keys()
outputs = []
for item in items:
handler = self._items_to_handlers[item]
keys_to_tensors = {key: example[key] for key in handler.keys}
outputs.append(handler.tensors_to_item(keys_to_tensors))
return outputs
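A minimal sketch of the two-stage decoding described in the class docstring, with made-up item names ('image', 'labels'):
import tensorflow as tf
slim_example_decoder = tf.contrib.slim.tfexample_decoder
# One serialized tf.train.Example fed in at runtime.
serialized_example = tf.placeholder(tf.string, shape=[])
# Stage 1: how to parse the Example proto.
keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
}
# Stage 2: how to turn the parsed tensors into the requested items.
items_to_handlers = {
    'image': slim_example_decoder.Image(
        image_key='image/encoded', format_key='image/format', channels=3),
    'labels': slim_example_decoder.Tensor('image/object/class/label'),
}
decoder = slim_example_decoder.TFExampleDecoder(keys_to_features,
                                                items_to_handlers)
# decode() returns the item tensors in the order given by `items`.
image, labels = decoder.decode(serialized_example, items=['image', 'labels'])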