learn_runner — usage interface
from tensorflow.contrib.learn import learn_runner
def run_experiment(argv=None):
    """Launch the experiment through learn_runner.

    Args:
        argv: Unused; kept so the function can be passed to tf.app.run.
    """
    # NOTE(review): `experiment_fn`, `_schedule`, `run_config` and `params`
    # are assumed to be module-level names defined elsewhere in this file.
    learn_runner.run(
        experiment_fn=experiment_fn,  # First-class function
        schedule=_schedule,           # What to run, e.g. "train" or "train_and_evaluate"
        run_config=run_config,        # RunConfig
        hparams=params                # HParams
    )
def experiment_fn(run_config, params):
    """Create an experiment to train and evaluate the model.

    Args:
        run_config (RunConfig): Configuration for Estimator run.
        params (HParams): Hyperparameters.

    Returns:
        (Experiment) Experiment for training the mnist model.
    """
    # Define the estimator.
    estimator = get_estimator(run_config, params)

    # Setup data loaders.
    # mnist = mnist_data.read_data_sets(FLAGS.data_dir, one_hot=False)
    train_input_fn, train_input_hook = get_train_inputs(
        params.train_batch_size, params.dataset_dir, params.dataset_file_pattern)
    eval_input_fn, eval_input_hook = get_val_inputs(
        params.eval_batch_size, params.dataset_dir, params.dataset_file_pattern)

    # Define the experiment.
    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,                           # Estimator
        train_input_fn=train_input_fn,                 # First-class function
        eval_input_fn=eval_input_fn,                   # First-class function
        train_steps=params.train_steps,                # Minibatch steps
        min_eval_frequency=params.eval_min_frequency,  # Eval frequency
        # train_monitors=[],                           # Hooks for training
        # eval_hooks=[eval_input_hook],                # Hooks for evaluation
        eval_steps=params.eval_steps                   # Use evaluation feeder until it's empty
    )
    return experiment
learn_runner.run() — annotated source excerpt
def run(experiment_fn, schedule=None, run_config=None,
        hparams=None):
    """Make and run an Experiment.

    It creates an Experiment by calling `experiment_fn`. Then it calls the
    function named as `schedule` of the Experiment.
    If schedule is not provided, then the default schedule for the current task
    type is used. The defaults are as follows:
        * 'ps' maps to 'serve'
        * 'worker' maps to 'train'
        * 'master' maps to 'local_run'
    If the experiment's config does not include a task type, then an exception
    is raised.

    Args:
        experiment_fn: A function that creates an `Experiment`.
            It accepts two arguments `run_config` and `hparams`, which should be
            used to create the `Estimator` (`run_config` passed as `config` to its
            constructor; `hparams` used as the hyper-parameters of the model).
            It must return an `Experiment`.
        schedule: The name of the method in the `Experiment` to run.
        run_config: `RunConfig` instance. The `run_config.model_dir` must be
            non-empty.
        hparams: `HParams` instance. The default hyper-parameters, which will be
            passed to the `experiment_fn` if `run_config` is not None.

    Returns:
        The return value of function `schedule`.
    """
    # 1. Get the experiment; the factory is wrapped so the RunConfig uid can be
    #    checked (i.e. experiment_fn must not mutate the config).
    wrapped_experiment_fn = _wrapped_experiment_fn_with_uid_check(experiment_fn)
    experiment = wrapped_experiment_fn(run_config=run_config, hparams=hparams)

    # 2. Get the schedule: fall back to the estimator's own config, then to the
    #    task-type-based default.
    run_config = run_config or experiment.estimator.config
    schedule = schedule or _get_default_schedule(run_config)

    # 3. Execute the schedule and return its result (this line was missing from
    #    the original excerpt although the docstring promises a return value).
    return _execute_schedule(experiment, schedule)
def _execute_schedule(experiment, schedule):
    """Execute the method named `schedule` of `experiment` and return its result."""
    # `schedule` is the *name* of an Experiment method, e.g. "train_and_evaluate".
    task = getattr(experiment, schedule)
    return task()
def _get_default_schedule(config):
    """Returns the default schedule for the provided RunConfig.

    Raises:
        ValueError: If the config is distributed but has no (or an unknown)
            task type.
    """
    # Local (non-distributed) jobs always train and evaluate in-process.
    if not config or not _is_distributed(config):
        return 'train_and_evaluate'

    if not config.task_type:
        raise ValueError('Must specify a schedule')

    if config.task_type == run_config_lib.TaskType.MASTER:
        # TODO(rhaertel): handle the case where there is more than one master
        # or explicitly disallow such a case.
        return 'train_and_evaluate'
    elif config.task_type == run_config_lib.TaskType.PS:
        return 'run_std_server'
    elif config.task_type == run_config_lib.TaskType.WORKER:
        return 'train'

    # The excerpt silently fell through to None here; fail loudly instead
    # (this matches the upstream learn_runner implementation).
    raise ValueError('No default schedule for task type: %s' % (config.task_type,))
def _is_distributed(config):
    """Returns true if this is a distributed job."""
    if not config.cluster_spec:
        return False
    # This is considered a distributed job if there is more than one task
    # in the cluster spec.
    task_count = sum(
        len(list(config.cluster_spec.job_tasks(job)))
        for job in config.cluster_spec.jobs)
    return task_count > 1
run_config — references:
https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig
https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig#master
experiment — core flow
Initialization is triggered by:
# learn_runner.run invokes the (uid-checked) factory, which builds the Experiment:
experiment = wrapped_experiment_fn(run_config=run_config, hparams=hparams)

def experiment_fn(run_config, params):
    """Skeleton of a user-supplied experiment factory (illustration only)."""
    # .......
    # Get the estimator.
    estimator = get_estimator()

    # Define the experiment.
    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,                           # Estimator
        train_input_fn=train_input_fn,                 # First-class function
        eval_input_fn=eval_input_fn,                   # First-class function
        train_steps=params.train_steps,                # Minibatch steps
        min_eval_frequency=params.eval_min_frequency,  # Eval frequency
        # train_monitors=[],                           # Hooks for training
        # eval_hooks=[eval_input_hook],                # Hooks for evaluation
        eval_steps=params.eval_steps                   # Use evaluation feeder until it's empty
    )
def get_estimator(run_config, params)