目录
第二部分 : 模型的建立以及操作(train, eval, predict)——model_fn
最近在看Bert的源码,作者是使用Estimator来实现的数据输入,训练,预测等功能。所以,对Tensorflow中Estimator的使用做简单的总结。主要是input_fn和model_fn的使用。
第一部分 : 数据的输入——input_fn
def input_fn_builder(params):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
//可以执行其它操作
//例如在Bert代码中,这里是一个_decode_record()函数,用于读取tfrecord文件
def input_fn(params):
"""The actual input function."""
//进行具体的操作
return features,labels
//返回的顺序要和 model_fn的输入的前两个参数一致
//或者 dataset元素 格式为(features,label)元组也可以,即return dataset.make_one_shot_iterator().get_next()
return input_fn
总结:在bert中,实际上是先把数据转换成tfrecord形式,然后,在input_fn_builder中tfrecord文件进行读取,最后生成字典。其以Batch的形式,只返回了一个字典features,该字典中已经包含了labels的信息。
第二部分 : 模型的建立以及操作(train, eval, predict)——model_fn
model_fn可以拆分为两大部分:create_model,model_fn_builder(返回model_fn)
def create_model(params):
//定义最后的网络结构 和 损失函数 以及 返回值
//bert这里相当于从modeling.py中取出模型的最后输出。然后再加入loss层。
//返回值可以为(loss, per_example_loss, logits, probabilities, predict)
def model_fn_builder(params):
"""实际创建Estimator的model_fn""" """Returns `model_fn` closure for TPUEstimator."""
//类似的,此处同样可以有其他操作,即其他函数
def model_fn(features, labels, mode, params,config) //estimator需要的model_fn参数固定
/*
features: from input_fn的返回 切记返回的顺序
labels: from input_fn 的返回 切记返回的顺序
mode: tf.estimator.ModeKeys实例的一种
params: 在初始化estimator时 传入的参数列表,dict形式,或者直接使用self.params也可以
config:初始化estimator时的Runconfig
*/
(total_loss, per_example_loss, logits, probabilities,predicts) = create_model(all_params)
output_spec = None
if mode==tf.estimator.ModeKeys.TRAIN: // 执行训练
...
output_spec = tf.estimator.EstimatorSpec(mode=mode,
loss=total_loss,
train_op=train_op,
...)
elif mode==tf.estimator.ModeKeys.EVAL: //评估
...
output_spec = tf.estimator.EstimatorSpec(mode=mode,
loss=loss,
eval_metrics=eval_metrics,
...)
elif mode=tf.estimator.ModeKeys.PREDICT: // 预测
...
output_spec = tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions,
...)
//......其它操作
//最后返回
return output_spec
return model_fn
重点关注mode_fn的参数。这四个参数不需要明确的人为指定。1. 对于前两个参数features, labels,只需要使用数据部分的返回值input_fn作为输入,则tensorflow会自动按顺序选择input_fn的前两个参数作为features,labels的输入;2 . 第三个参数 mode。当有人调用train、evaluate或predict时,Estimator框架会调用模型函数并根据调用方式将mode参数设置为ModeKeys.TRAIN,ModeKeys.EVAL,ModeKeys.PREDICT三个值中的一个; 3. params是可以人为指定的传入的参数列表; 4. config一般由run_config = tf.contrib.tpu.RunConfig(...)定义,具体可参考bert源码。
这里单独说一下参数mode。理论上,模型函数需要提供代码来处理全部三个mode值(当然如果代码是自己使用,而你只需要预测和训练,为了省事也可以不写Eval模式)。对于每个mode值,代码都必须返回 tf.estimator.EstimatorSpec的一个实例,其中包含调用程序所需的信息。
我们来详细了解各个mode。
(1)训练
if mode==tf.estimator.ModeKeys.TRAIN: // 执行训练
...
output_spec = tf.estimator.EstimatorSpec(mode=mode,
loss=total_loss,
train_op=train_op,
...)
对于训练模式,tf.estimator.EstimatorSpec的三个必须参数为mode, loss以及train_op。mode由调用时,自动按照调用方式赋值,loss由create_model的返回值得到或是进一步计算得到。train_op需要定义好优化器。最简单的train_op可以由下定义:
optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
(2)评估
elif mode==tf.estimator.ModeKeys.EVAL: //评估
...
output_spec = tf.estimator.EstimatorSpec(mode=mode,
loss=loss,
eval_metrics=eval_metrics,,
...)
对于评估模式,tf.estimator.EstimatorSpec的三个必须参数为mode, loss以及eval_metrics。mode在调用estimator时,自动按照调用方式赋值,loss由create_model的返回值得到或是进一步计算得到。eval_metrics是一个字典,在bert中给出的句子分类的例子中,其为
eval_metrics = {
"eval_accuracy": accuracy,
"eval_loss": loss,
}
评估模式的特点是,eval_metrics中的变量可以取出,并且直接打印出来。
(3)预测
elif mode=tf.estimator.ModeKeys.PREDICT: // 预测
...
output_spec = tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions,
...)
对于预测模式,tf.estimator.EstimatorSpec的两个必须参数为mode, 以及predictions。mode在调用estimator时,自动按照调用方式赋值。predictions是一个字典,对于bert中的句子分类的例子,预测值只需要返回每一个句子的概率即可,predictions={"probabilities": probabilities}。然而,如果对于序列标注任务,则可以使 predictions = { "label_ids": label_ids, "predicts": predicts },之后,根据label_ids和predicts计算recall等指标。
第三部分:Estimator的建立,以及调用
前面对于Estimator的准备工作都已经完成了,即数据输入部分有了,模型建立部分有了,模型的三个功能也有了。接下来就是要实际的把模型跑起来。
(1)建立Estimator
run_config = tf.contrib.tpu.RunConfig(...)
model_fn = model_fn_builder(...)
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size,
predict_batch_size=FLAGS.predict_batch_size)
model_fn和run_config可以参考bert源码。可以看到,建立一个estimator需要指定gpu,tpu或是cpu使用,model_fn, run_config,以及训练,评估,预测的Batch_size。前三个应该为必须参数,batch_size在内部直接指定也可以,这些batch_size属于params里面的参数。
(2)Estimator的调用
if FLAGS.do_train:
train_input_fn = input_fn_builder(...)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
//train模式无返回值
if FLAGS.do_predict:
predict_input_fn = input_fn_builder(...)
result = estimator.predict(input_fn=predict_input_fn)
//这里返回的result为predictions。是一个字典
if FLAGS.do_eval:
eval_input_fn = input_fn_builder(...)
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
//返回值为eval_metrics
最后,有一个问题params["batch_size"]这个是在哪里传入的? 应该是estimator建立的时候,传入的params参数,但总感觉对不上。有待解决。。