How to use tf.estimator

How to use tf.estimator.Estimator to train/evaluate/predict?

1. Pipeline Overview

  • create TFRecord dataset
  • create input function
  • create model function for estimator
  • create EstimatorSpec
  • create TrainSpec
  • start to train

The entry point is located in model_main.py.

import tensorflow as tf
from absl import flags
from object_detection import model_lib

flags.DEFINE_...()  # flags include: model_dir, pipeline_config_path, num_train_steps, eval_training_data=True/False, checkpoint_dir
# Pay attention to the flag checkpoint_dir: if `checkpoint_dir` is provided, this
# binary operates in eval-only mode, writing resulting metrics to `model_dir`.
FLAGS = flags.FLAGS

def main(unused_argv):
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
  train_and_eval_dict = model_lib.create_estimator_and_inputs(...)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  train_spec, eval_specs = model_lib.create_train_and_eval_specs(
      train_input_fn,
      eval_input_fns,
      eval_on_train_input_fn,
      predict_input_fn,
      train_steps,
      eval_on_train_data=False)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])

We have 4 very important variables: config, train_and_eval_dict, train_spec and eval_specs. In particular, train_and_eval_dict contains the estimator instance and all the input functions.

2. Create TFRecord Files
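
A minimal sketch of writing one example into a TFRecord file: the feature keys follow the usual Object Detection API convention, but the exact schema and the helper scripts live in object_detection/dataset_tools, so treat the file name and values here as illustrative.

import tensorflow as tf

def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# One image with one normalized bounding box (hypothetical file and values).
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': _bytes_feature(open('image.jpg', 'rb').read()),
    'image/format': _bytes_feature(b'jpeg'),
    'image/height': _int64_feature(480),
    'image/width': _int64_feature(640),
    'image/object/bbox/xmin': _float_list_feature([0.1]),
    'image/object/bbox/xmax': _float_list_feature([0.5]),
    'image/object/bbox/ymin': _float_list_feature([0.2]),
    'image/object/bbox/ymax': _float_list_feature([0.6]),
    'image/object/class/label': _int64_feature(1),
}))

with tf.python_io.TFRecordWriter('train.record') as writer:
  writer.write(example.SerializeToString())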

3. Create Input Functions

Pay attention to these 4 lines in model_main.py:

train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fns = train_and_eval_dict['eval_input_fns']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']

In model_lib.create_estimator_and_inputs(...), we can see how these input functions are built from the pipeline configuration; the two subsections below look at the expected signature and at how the builders create them.

3.1. Signature of Input Function
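
The Estimator expects an input function that takes no required arguments (the TPU variants accept a params dict) and returns a tf.data.Dataset yielding (features, labels) pairs, or the (features, labels) tuple directly. A minimal sketch, assuming the TFRecord file from section 2 and a hypothetical parse_fn that decodes one tf.train.Example into such a pair:

def train_input_fn(params=None):
  # Read the TFRecord file and turn it into (features, labels) pairs.
  dataset = tf.data.TFRecordDataset(['train.record'])
  dataset = dataset.map(parse_fn)   # parse_fn is a placeholder decoder
  dataset = dataset.shuffle(1000).batch(24).repeat()
  return dataset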

3.2. Create Input Function from Configuration
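
The four input functions are not written by hand; model_lib.create_estimator_and_inputs builds them from the parsed configs using the builders in object_detection/inputs.py. Roughly, condensed from model_lib.py (see that file for the exact code, including eval_on_train_input_fn):

train_input_fn = inputs.create_train_input_fn(
    train_config=train_config,
    train_input_config=train_input_config,
    model_config=model_config)
eval_input_fns = [
    inputs.create_eval_input_fn(
        eval_config=eval_config,
        eval_input_config=eval_input_config,
        model_config=model_config)
    for eval_input_config in eval_input_configs]
predict_input_fn = inputs.create_predict_input_fn(
    model_config=model_config,
    predict_input_config=eval_input_configs[0])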

4. Create Estimator Instance

4.1. How to create estimator in model_lib

Let’s look into the function model_lib.create_estimator_and_inputs(run_config, pipeline_config_path, model_fn_creator=create_model_fn, ...); the complete parameter list is as follows:
run_config, hparams, pipeline_config_path, config_override=None, train_steps=None, sample_1_of_n_eval_examples=1, sample_1_of_n_eval_on_train_examples=1, model_fn_creator=create_model_fn, use_tpu_estimator=False, use_tpu=False, num_shards=1, params=None, override_eval_num_epochs=True, save_final_config=False, **kwargs

So we get this pipeline:
pipeline_config_path -> configs -> model_config. Then from this model_config, we can get a detection_model_fn:
detection_model_fn = functools.partial(model_builder.build, model_config=model_config), and it will be used in this line:
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)

That means we use model_fn_creator to create a model_fn from a detection_model_fn. The detection_model_fn itself is just a callable that returns a DetectionModel instance (model_builder.build with model_config already bound), as can be seen in the model_builder.build function.

Here a function and a class are very similar. For example, if we have defined a function fn() and a class Cls(), we can call fn() to return an object, or call Cls() to return an instance (which is also an object). So fn() and Cls() are very similar to some extent: we can even write fn = Cls and then call fn() to get an object (actually an instance of Cls).
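
A tiny illustration of this point, with hypothetical names (not from the Object Detection code):

class Detector(object):
  def __init__(self, is_training):
    self.is_training = is_training

def build_detector(is_training):
  return Detector(is_training)

# Both names are just callables that return a Detector object:
fn1 = Detector        # fn1(is_training=True) -> Detector instance
fn2 = build_detector  # fn2(is_training=True) -> Detector instance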

For example, the name model_fn is bound to model_fn_creator(detection_model_fn, configs, hparams, use_tpu), which is the same as calling create_model_fn(detection_model_fn, configs, hparams, use_tpu); what gets returned is the inner function model_fn defined inside create_model_fn. That is to say, the model_fn used in create_estimator_and_inputs(run_config, pipeline_config_path, model_fn_creator=create_model_fn) is exactly the function model_fn defined inside model_lib.create_model_fn(detection_model_fn, configs, hparams, use_tpu=False). When we call model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu), we actually call model_fn = create_model_fn(detection_model_fn, configs, hparams, use_tpu=False), and we get back a function named model_fn.

The model_fn must have this signature:
def model_fn(features, labels, mode, params=None) (as defined inside create_model_fn(detection_model_fn, configs, hparams, use_tpu=False)).
Let’s go inside this function and see what it does.

def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
  """Creates a model function for `Estimator`.
  Args:
    detection_model_fn: Function that returns a `DetectionModel` instance.
    configs: Dictionary of pipeline config objects.
    hparams: `HParams` object.
    use_tpu: Boolean indicating whether model should be constructed for
        use on TPU.
  Returns:
    `model_fn` for `Estimator`.
  """
  train_config = configs['train_config']
  eval_input_config = configs['eval_input_config']
  eval_config = configs['eval_config']
  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.
    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.
    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = (mode==tf.estimator.ModeKeys.TRAIN)
    labels = unstack_batch(labels, unpad_groundtruth_tensors)  # in the full code this comes from train_config (TRAIN) or is False (EVAL)
    detection_model = detection_model_fn(is_training=is_training, add_summaries=(not use_tpu))
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      detection_model.provide_groundtruth(...)
    preprocessed_images = features[fields.InputDataFields.image]
    prediction_dict = detection_model.predict( preprocessed_images, features[fields.InputDataFields.true_image_shape])
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
      detections = detection_model.postprocess( prediction_dict, features[fields.InputDataFields.true_image_shape])
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      losses_dict = detection_model.loss(
          prediction_dict, features[fields.InputDataFields.true_image_shape])
      losses = [loss_tensor for loss_tensor in losses_dict.values()]
      total_loss = tf.add_n(losses, name='total_loss')
      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      if mode == tf.estimator.ModeKeys.TRAIN:
        # clip_gradients_value, trainable_variables and summaries are derived
        # from train_config in the parts of model_lib.py omitted here.
        train_op = tf.contrib.layers.optimize_loss(
            loss=total_loss,
            global_step=global_step,
            learning_rate=None,
            clip_gradients=clip_gradients_value,
            optimizer=training_optimizer,
            update_ops=detection_model.updates(),
            variables=trainable_variables,
            summaries=summaries,
            name='')  # Preventing scope prefix on all variables.
    if mode == tf.estimator.ModeKeys.PREDICT:
      exported_output = exporter_lib.add_output_tensor_nodes(detections)
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(exported_output)
      }
    # eval_metric_ops and scaffold are built in the parts of model_lib.py omitted here.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=detections,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs,
        scaffold=scaffold)

To put it simply, the steps are as follows:

  1. The model_fn() function receives features, labels, mode and params as arguments.
  2. It creates a detection model via detection_model = detection_model_fn(is_training=is_training, add_summaries=(not use_tpu)) (detection_model_fn builds a model.DetectionModel such as SSDMetaArch or FasterRCNNMetaArch, so calling detection_model_fn() creates a model.DetectionModel instance).
  3. It uses this detection_model to provide_groundtruth, get the prediction_dict, get the postprocessed detections, define the loss, build the optimizer, create the train_op, ...
  4. It returns an EstimatorSpec instance.
  5. The model_fn() is then used to create the estimator (see the sketch below).
  6. I think the main purpose is to reduce the number of arguments passed from model_fn_creator down to model_fn.
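
Putting step 5 into code: once model_fn exists, the estimator itself (in the non-TPU case) is just a plain tf.estimator.Estimator wrapping it. A condensed sketch of this part of model_lib.create_estimator_and_inputs:

detection_model_fn = functools.partial(
    model_builder.build, model_config=model_config)
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)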

5. Create EstimatorSpec Instance
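
model_fn ends by packing everything it built into a tf.estimator.EstimatorSpec; which fields matter depends on the mode. A minimal sketch, where total_loss, train_op, eval_metric_ops, detections and export_outputs stand for the tensors built inside model_fn:

if mode == tf.estimator.ModeKeys.TRAIN:
  spec = tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_op)
elif mode == tf.estimator.ModeKeys.EVAL:
  spec = tf.estimator.EstimatorSpec(mode=mode, loss=total_loss,
                                    eval_metric_ops=eval_metric_ops)
else:  # PREDICT
  spec = tf.estimator.EstimatorSpec(mode=mode, predictions=detections,
                                    export_outputs=export_outputs)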

6. Create TrainSpec Instance
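
model_lib.create_train_and_eval_specs wraps the input functions into TrainSpec/EvalSpec objects. A condensed sketch (the real function also sets eval spec names and exporters):

train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn, max_steps=train_steps)
eval_specs = [
    tf.estimator.EvalSpec(name=str(index), input_fn=eval_input_fn, steps=None)
    for index, eval_input_fn in enumerate(eval_input_fns)]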

7. Start to Train and Evaluate or Predict
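
With all pieces in place, training plus periodic evaluation is the single call already shown in model_main.py. The other two snippets below are illustrative: the eval-only call is roughly what model_main.py does when --checkpoint_dir is set, and the predict loop assumes a hypothetical my_predict_input_fn that yields feature tensors.

# Train and periodically evaluate:
tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])

# Evaluation-only run on the latest checkpoint:
estimator.evaluate(
    input_fn=eval_input_fns[0],
    steps=None,
    checkpoint_path=tf.train.latest_checkpoint(FLAGS.checkpoint_dir))

# Prediction: estimator.predict expects an input function that yields features.
for prediction_dict in estimator.predict(input_fn=my_predict_input_fn):
  print(prediction_dict.keys())  # e.g. detection_boxes, detection_scores, ...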
