TensorFlow 版本:1.11.0
在 TensorFlow 1.4 版本中,Google 新引入了一个新 API:tf.estimator.train_and_evaluate
。提出这个 API 的目的是:代替 tf.contrib.learn.Experiment
。
1. tf.estimator.train_and_evaluate
简介
train_and_evaluate
API 用来 train
然后 evaluate
一个 Estimator。调用方式如下:
tf.estimator.train_and_evaluate(
estimator,
train_spec,
eval_spec
)
这个函数除了 train 和 evaluate 之外,还可选的提供了模型的导出功能,这样就可以把一个训练好的模型直接转交给业务部门来使用了,可以算是“产学研”一条龙服务了。
该函数的参数有三个:
estimator
:一个Estimator
实例。train_spec
:一个TrainSpec
实例。用来配置训练过程。eval_spec
:一个EvalSpec
实例。用来配置评估过程、(可选)模型的导出。
该函数的返回值有一个:
Estimator.evaluate
的结果 及 前面指定的ExportStrategy
的输出结果。当前,尚未定义分布式训练模式的返回值。
实际上,如果直接使用 Estimator API,完成 train 和 evaluate 已经是很简单的任务了,为什么我们还要使用 train_and_evaluate 这个函数呢?按官方文档的说法:这个函数可以保证 本地 和 分布式 环境下行为的一致性。也就是说,使用 Estimator
和 train_and_evaluate
编写的程序同时支持本地、集群上的训练,而不需要修改任何代码。可以想像一下,在完成了本地 CPU 训练的测试之后,直接 push 到 Cloud ML Engine 上,分分钟完成一个模型的训练,甚至还可以直接使用 TPU 集群(只要你保证模型里的 op 都是对 TPU 兼容的),多么方便的一个工具啊!
这个函数默认的分布式策略是:parameter server-based between-graph replication。对于其它的分布式策略的使用,可以参照 DistributionStrategies 。TensorFlow 关于分布式的官方文档见 Distributed TensorFlow。
当然,方便的背后一般都有代价。为了保证代码在本地和集群上都可以正常终止,所以只能使用 Estimator 的 max_steps
参数设定终止条件。所以,如果想使用别的方式终止训练,可能就需要一些“技巧”了。
2. 参数说明
上面我们已经知道 train_and_evaluate 有三个参数,第一个先放在一边,因为这个参数就是一个 Estimator 的实例。我们先来看一下另外两个参数:
2.1 train_spec
参数
train_spec
参数接收一个 tf.estimator.TrainSpec
实例。
# TrainSpec的参数
__new__(
cls, # 这个参数不用指定,忽略即可。
input_fn,
max_steps=None,
hooks=None
)
其中:
input_fn
: 参数用来指定数据输入。max_steps
: 参数用来指定训练的最大步数,这是训练的唯一终止条件。hooks
: 参数用来挂一些tf.train.SessionRunHook
,用来在 session 运行的时候做一些额外的操作,比如记录一些 TensorBoard 日志什么的。
2.2 eval_spec
参数
eval_spec
参数接收一个 tf.estimator.EvalSpec
实例。相比 TrainSpec
,EvalSpec
的参数多很多。因为 EvalSpec
不仅可以指定评估过程,还可以指定导出模型的功能(可选)。
__new__(
cls, # 这个参数不用指定,忽略即可。
input_fn,
steps=100, # 评估的迭代步数,如果为None,则在整个数据集上评估。
name=None,
hooks=None,
exporters=None,
start_delay_secs=120,
throttle_secs=600
)
其中:
input_fn
: 含义同2.1。steps
: 用来指定评估的迭代步数,如果为None,则在整个数据集上评估。name
:如果要在多个数据集上进行评估,通过name
参数可以保证不同数据集上的评估日志保存在不同的文件夹中,从而区分不同数据集上的评估日志。
不同的评估日志保存在独立的文件夹中,在 TensorBoard 中从而独立的展现。hooks
:含义同2.1exporters
:一个tf.estimator.export
模块中的类的实例。start_delay_secs
:调用train_and_evaluate
函数后,多少秒之后开始评估。第一次评估发生在start_delay_secs + throttle_secs
秒后。throttle_secs
:多少秒后又开始评估,如果没有新的 checkpoints 产生,则不评估,所以这个间隔是最小值。
3. 非分布式实例
# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
categorial_feature_a_emb = embedding_column(
categorical_column=categorial_feature_a, ...)
... # other feature columns
estimator = DNNClassifier(
feature_columns=[categorial_feature_a_emb, ...],
hidden_units=[1024, 512, 256])
# Or set up the model directory
# estimator = DNNClassifier(
# config=tf.estimator.RunConfig(
# model_dir='/my_model', save_summary_steps=100),
# feature_columns=[categorial_feature_a_emb, ...],
# hidden_units=[1024, 512, 256])
# Input pipeline for train and evaluate.
def train_input_fn(): # returns x, y
# please shuffle the data.
pass
def eval_input_fn(): # returns x, y
pass
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
注意:在当前的实现中,estimator.evaluate
将被调用多次。这意味着在每次评估时,会重新创建评估图(包括eval_input_fn
)。estimator.train
只会被调用一次。
4. 分布式实例
上面的代码可以在不加修改的情况下用于分布式训练,但请确保所有 worker
的 RunConfig.model_dir
设置为相同的目录(例如,一个所有 worker 都可以读写的共享文件系统。唯一需要做的就是正确得设置所有 worker 的环境变量 TF_CONFIG
。
设置环境变量的方式会随系统而变化。例如,在 Linux 上,设置环境变量的方式如下($
是命令提示符):
$ TF_CONFIG='<replace_with_real_content>' python train_model.py
训练的集群配置如下:
cluster = {"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]}
chief training worker(必须有,且只能有一个)的 TF_CONFIG
应该被设置为:
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "chief", "index": 0}
}'
注意:chief worker 与其他 non-chief training worker 一样,也进行模型的训练 job。chief worker 除了进行模型训练,还管理一些其它 work(例如:checkpoint 保存、恢复,写入 summaries 等)。
non-chief training worker(可选,可以有多个)的 TF_CONFIG
应该被设置为:
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "worker", "index": 0}
}'
上面的 task.index
表示 worker 的编号。本例中,有三个 non-chief training worker,所以编号为 0,1,2。
parameter server(可以是多个)的 TF_CONFIG
应该被设置为:
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "ps", "index": 0}
}'
由于例子中参数服务器的个数为两个,所以 task.index
编号分别为 0,1。
评估的集群配置如下:
评估 task 的 TF_CONFIG
如下所示。评估是一个特殊的 task,该 task 不是训练集群的一部分。有可能只有一个。该 task 被用于模型评估。
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "evaluator", "index": 0}
}'
当 distribute
或 experimental_distribute.train_distribute
及 experimental_distribute.remote_cluster
被设置时,这个方法将开始在本机运行一个 client,该 client 将连接到 remote_cluster
,以进行训练和评估。
参考文档:
tf.estimator.train_and_evaluate
官方文档(英文)- tf.estimator.train_and_evaluate 试用
- 推荐一个 Estimator+Experiment 的实例:tensorflow/models里的cifar10_estimator