How to Use tf.estimator.Estimator to Train/Evaluate/Predict?
- 4. Create Estimator Instance
- 4.1. How to create estimator in model_lib
- 5. Create EstimatorSpec Instance
- 6. Create TrainSpec Instance
- 7. Start Train Operation
1. Pipeline Overview
- create tfrecord dataset
- create input function
- create model function for estimator
- create estimatorspec
- create trainspec
- start to train
The entry point is located in model_main.py.
import tensorflow as tf
from absl import flags
from object_detection import model_lib

flags.DEFINE_...()  # flags include: model_dir, pipeline_config_path, num_train_steps, eval_training_data=True/False, checkpoint_dir
# Note the checkpoint_dir flag: if `checkpoint_dir` is provided, this binary
# operates in eval-only mode, writing resulting metrics to `model_dir`.
FLAGS = flags.FLAGS


def main(unused_argv):
    config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
    train_and_eval_dict = model_lib.create_estimator_and_inputs(...)
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fns = train_and_eval_dict['eval_input_fns']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fns,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_on_train_data=False)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
There are four important variables here: config, train_and_eval_dict, train_spec, and eval_specs. The most important is train_and_eval_dict, which contains all the input functions, keyed by name, and an estimator instance.
2. Create TFRecord Files
3. Create Input Functions
Pay attention to these four lines in model_main.py:
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fns = train_and_eval_dict['eval_input_fns']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
In model_lib.create_estimator_and_inputs(...), we can see where these input functions are created.
3.1. Signature of Input Function
3.2. Create Input Function from Configuration
4. Create Estimator Instance
4.1. How to create estimator in model_lib
Let's look at the function model_lib.create_estimator_and_inputs(run_config, pipeline_config_path, model_fn_creator=create_model_fn, ...). Its complete parameter list is:
run_config, hparams, pipeline_config_path, config_override=None, train_steps=None, sample_1_of_n_eval_examples=1, sample_1_of_n_eval_on_train_examples=1, model_fn_creator=create_model_fn, use_tpu_estimator=False, use_tpu=False, num_shards=1, params=None, override_eval_num_epochs=True, save_final_config=False, **kwargs
So we get this pipeline:
pipeline_config_path -> configs -> model_config. From this model_config, we get a detection_model_fn:
detection_model_fn = functools.partial(model_builder.build, model_config=model_config)
It is then used in this line:
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
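That means we use model_fn_creator to create a model_fn from a detection_model_fn. Strictly speaking, detection_model_fn is not a class: it is a partial application of model_builder.build, which constructs and returns an instance of a DetectionModel subclass. Calling it therefore behaves much like calling the class itself, as can be seen in the model_builder.build function.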
Here, functions and classes are very similar: both are callables. For example, given a function fn and a class C, we can call fn() to get an object or C() to get an instance (which is also an object). So I think functions and classes are interchangeable to some extent: we can assign fn = C and then call fn() to get an object (which is actually an instance of C).
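The point about functions and classes can be checked in plain Python. The sketch below uses hypothetical stand-ins (ToyModel and build, which are not part of the object detection library) and shows that a class, a factory function, and a functools.partial of that factory are all callables that hand back an instance.

```python
import functools


class ToyModel:
    """Hypothetical stand-in for a DetectionModel subclass."""

    def __init__(self, model_config, is_training):
        self.model_config = model_config
        self.is_training = is_training


def build(model_config, is_training):
    """Hypothetical stand-in for a builder: returns a ToyModel instance."""
    return ToyModel(model_config, is_training)


# A class can be bound to another name and called like a function.
fn = ToyModel
from_class = fn(model_config="ssd", is_training=True)

# functools.partial pre-fills model_config, leaving is_training for the caller.
detection_model_fn = functools.partial(build, model_config="ssd")
from_partial = detection_model_fn(is_training=False)

print(type(from_class).__name__, from_class.is_training)
print(type(from_partial).__name__, from_partial.is_training)
```

In both cases the caller only sees "something callable that returns a model", which is all that the later code needs.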
For example, the name model_fn is bound to the return value of model_fn_creator(detection_model_fn, configs, hparams, use_tpu). Because model_fn_creator defaults to create_model_fn, this is the same as create_model_fn(detection_model_fn, configs, hparams, use_tpu), which returns its inner function model_fn. That is to say, the model_fn in create_estimator_and_inputs(run_config, pipeline_config_path, model_fn_creator=create_model_fn) is the function model_fn defined inside model_lib.create_model_fn(detection_model_fn, configs, hparams, use_tpu=False). When we call model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu), we are actually calling model_fn = create_model_fn(detection_model_fn, configs, hparams, use_tpu=False), and what we get back is a function named model_fn.
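The name binding described above can be reproduced with a plain Python closure. This is a minimal sketch, not the library's code: create_toy_model_fn, the lambda passed as detection_model_fn, and the dictionary it returns are all hypothetical, and only the argument names mirror the text.

```python
import inspect


def create_toy_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
    """Hypothetical factory: captures its arguments and returns a model_fn."""

    def model_fn(features, labels, mode, params=None):
        # The outer arguments are available here through the closure,
        # so they do not need to appear in model_fn's own signature.
        params = params or {}
        model = detection_model_fn(is_training=(mode == "train"))
        return {"mode": mode, "model": model, "num_configs": len(configs)}

    return model_fn


# model_fn_creator is just another name for the factory.
model_fn_creator = create_toy_model_fn
model_fn = model_fn_creator(
    detection_model_fn=lambda is_training: ("toy-model", is_training),
    configs={"train_config": None},
    hparams=None,
)

print(model_fn.__name__)
print(list(inspect.signature(model_fn).parameters))
print(model_fn(features={}, labels=None, mode="train"))
```

The returned function keeps the name model_fn and takes only features, labels, mode and params, while the factory's own arguments stay captured in the closure.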
The model_fn should have this signature:
def model_fn(features, labels, mode, params=None):
(as shown in the function create_model_fn(detection_model_fn, configs, hparams, use_tpu=False)).
Let’s go inside this function and see what it does.
def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
    """Creates a model function for `Estimator`.

    Args:
        detection_model_fn: Function that returns a `DetectionModel` instance.
        configs: Dictionary of pipeline config objects.
        hparams: `HParams` object.
        use_tpu: Boolean indicating whether model should be constructed for
            use on TPU.

    Returns:
        `model_fn` for `Estimator`.
    """
    # Abridged excerpt: names such as unpad_groundtruth_tensors,
    # clip_gradients_value, trainable_variables, summaries, eval_metric_ops
    # and scaffold are defined in the full source.
    train_config = configs['train_config']
    eval_input_config = configs['eval_input_config']
    eval_config = configs['eval_config']

    def model_fn(features, labels, mode, params=None):
        """Constructs the object detection model.

        Args:
            features: Dictionary of feature tensors, returned from `input_fn`.
            labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
                otherwise None.
            mode: Mode key from tf.estimator.ModeKeys.
            params: Parameter dictionary passed from the estimator.

        Returns:
            An `EstimatorSpec` that encapsulates the model and its serving
            configurations.
        """
        params = params or {}
        total_loss, train_op, detections, export_outputs = None, None, None, None
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        labels = unstack_batch(labels, unpad_groundtruth_tensors)
        detection_model = detection_model_fn(
            is_training=is_training, add_summaries=(not use_tpu))

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            detection_model.provide_groundtruth(...)

        preprocessed_images = features[fields.InputDataFields.image]
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape])
        if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
            detections = detection_model.postprocess(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            losses_dict = detection_model.loss(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])
            losses = [loss_tensor for loss_tensor in losses_dict.values()]
            # Sum the individual losses so total_loss is no longer None.
            total_loss = tf.add_n(losses, name='total_loss')
            if 'graph_rewriter_config' in configs:
                graph_rewriter_fn = graph_rewriter_builder.build(
                    configs['graph_rewriter_config'], is_training=is_training)
                graph_rewriter_fn()
            global_step = tf.train.get_or_create_global_step()
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)

        # The train op is only needed in TRAIN mode.
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.contrib.layers.optimize_loss(
                loss=total_loss,
                global_step=global_step,
                learning_rate=None,
                clip_gradients=clip_gradients_value,
                optimizer=training_optimizer,
                update_ops=detection_model.updates(),
                variables=trainable_variables,
                summaries=summaries,
                name='')  # Preventing scope prefix on all variables.

        if mode == tf.estimator.ModeKeys.PREDICT:
            exported_output = exporter_lib.add_output_tensor_nodes(detections)
            export_outputs = {
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                    tf.estimator.export.PredictOutput(exported_output)
            }

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=detections,
            loss=total_loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs=export_outputs,
            scaffold=scaffold)

    return model_fn
To put it simply, the steps are:
- The model_fn() function receives features, labels, mode and params as parameters.
- Create a detection model with detection_model = detection_model_fn(is_training=is_training, add_summaries=(not use_tpu)). The detection_model_fn builds a model.DetectionModel subclass such as SSDMetaArch or FasterRCNNMetaArch, so calling detection_model_fn() creates a model.DetectionModel instance.
- Use this detection_model to provide_groundtruth, to get prediction_dict, to get the postprocessed detections, to define the loss, and to create the optimizer and train_op.
- Return an EstimatorSpec instance.
- The model_fn() is then used when creating the estimator.
- I think the main purpose of create_model_fn is to reduce the number of arguments: model_fn_creator takes detection_model_fn, configs, hparams and use_tpu, while the returned model_fn only takes the features, labels, mode and params that the estimator passes in.
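The steps above can be traced with a small pure-Python sketch. Everything in it is a hypothetical stand-in: ToyDetectionModel replaces the real DetectionModel, ToySpec replaces tf.estimator.EstimatorSpec, the mode strings stand in for tf.estimator.ModeKeys, and the train op is a placeholder string that is assumed to be built only in TRAIN mode.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.estimator.EstimatorSpec.
ToySpec = namedtuple("ToySpec", ["mode", "predictions", "loss", "train_op"])

# Stand-ins for tf.estimator.ModeKeys.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


class ToyDetectionModel:
    """Hypothetical model exposing the method names used in the walkthrough."""

    def __init__(self, is_training):
        self.is_training = is_training
        self.groundtruth = None

    def provide_groundtruth(self, labels):
        self.groundtruth = labels

    def predict(self, images):
        return {"raw": [2 * x for x in images]}

    def postprocess(self, prediction_dict):
        return {"detections": prediction_dict["raw"]}

    def loss(self, prediction_dict):
        pairs = zip(prediction_dict["raw"], self.groundtruth)
        return {"toy_loss": sum(abs(p - g) for p, g in pairs)}


def toy_model_fn(features, labels, mode, params=None):
    total_loss, train_op, detections = None, None, None
    model = ToyDetectionModel(is_training=(mode == TRAIN))
    if mode in (TRAIN, EVAL):
        model.provide_groundtruth(labels)
    prediction_dict = model.predict(features["image"])
    if mode in (EVAL, PREDICT):
        detections = model.postprocess(prediction_dict)
    if mode in (TRAIN, EVAL):
        total_loss = sum(model.loss(prediction_dict).values())
    if mode == TRAIN:
        train_op = "minimize(total_loss)"  # placeholder for a real train op
    return ToySpec(mode, detections, total_loss, train_op)


for mode in (TRAIN, EVAL, PREDICT):
    labels = None if mode == PREDICT else [2, 5]
    print(toy_model_fn({"image": [1, 2]}, labels, mode))
```

Running the loop shows which fields of the returned spec are populated in each mode: a loss and train op in training, a loss and detections in evaluation, and detections only in prediction.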