tf.estimator是比tf.contrib.slim更高级的API,能同时训练和验证模型。[更多]
- tf.estimator的有些参数是函数,但又不能带参数,如:
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9) model = tf.estimator.Estimator( model_fn=deeplab_model.deeplabv3_plus_model_fn, model_dir=FLAGS.model_dir, config=run_config, params={...})
其中model_fn就只是函数名,其参数在params中指定。
- tf.estimator.Estimator.train和tf.estimator.Estimator.evaluate的input_fn也不能带参数,如:
tf.logging.info("Start training.") model.train( input_fn=lambda: input_fn(True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval), hooks=train_hooks, # steps=1 # For debug ) tf.logging.info("Start evaluation.") eval_results = model.evaluate( input_fn=lambda: input_fn(False, FLAGS.data_dir, 1), hooks=eval_hooks, # steps=1 # For debug )
input_fn函数不能带参数,所以使用(1)lambda方式。此外还可以使用python的(2)functools.partial函数,如:
model.train( input_fn=functools.partial(True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval), hooks=train_hooks, # steps=1 # For debug )
还可以向上面的model_fn那样,(2)单独定义函数,再传入函数名,或者使用(4)python wrapper修饰器。
- 结束。