tensorflow contrib_Tensorflow中的Estimator

v2-e9b5ea41183bd7177eb58d49069e362f_1440w.jpg?source=172ae18b

最近在看Bert的源码,作者是使用Estimator来实现的数据输入,训练,预测等功能。所以,对Tensorflow中Estimator的使用做简单的总结。主要是input_fn和model_fn的使用。

参考:

https://blog.csdn.net/wwangfabei1989/article/details/90516318​blog.csdn.net 4. Tensorflow的Estimator实践原理​www.cnblogs.com
v2-33640c34ff5756917f38fcc27e34ba56_ipico.jpg

第一部分 : 数据的输入——input_fn

def 

总结:在bert中,实际上是先把数据转换成tfrecord形式,然后,在input_fn_builder中tfrecord文件进行读取,最后生成字典。其以Batch的形式,只返回了一个字典features,该字典中已经包含了labels的信息。

第二部分 : 模型的建立以及操作(train, eval, predict)——model_fn

model_fn可以拆分为两大部分:create_model,model_fn_builder(返回model_fn)

def 

重点关注mode_fn的参数。这四个参数不需要明确的人为指定。1. 对于前两个参数features, labels,只需要使用数据部分的返回值input_fn作为输入,则tensorflow会自动按顺序选择input_fn的前两个参数作为features,labels的输入;2 . 第三个参数 mode。当有人调用train、evaluate或predict时,Estimator框架会调用模型函数并根据调用方式将mode参数设置为ModeKeys.TRAIN,ModeKeys.EVAL,ModeKeys.PREDICT三个值中的一个; 3. params是可以人为指定的传入的参数列表; 4. config一般由run_config = tf.contrib.tpu.RunConfig(...)定义,具体可参考bert源码。

这里单独说一下参数mode。理论上,模型函数需要提供代码来处理全部三个mode值(当然如果代码是自己使用,而你只需要预测和训练,为了省事也可以不写Eval模式)。对于每个mode值,代码都必须返回 tf.estimator.EstimatorSpec的一个实例,其中包含调用程序所需的信息。我们来详细了解各个mode。

(1)训练

if 

对于训练模式,tf.estimator.EstimatorSpec的三个必须参数为mode, loss以及train_op。mode由调用时,自动按照调用方式赋值,loss由create_model的返回值得到或是进一步计算得到。train_op需要定义好优化器。最简单的train_op可以由下定义:

optimizer 

(2)评估

elif 

对于评估模式,tf.estimator.EstimatorSpec的三个必须参数为mode, loss以及eval_metrics。mode在调用estimator时,自动按照调用方式赋值,loss由create_model的返回值得到或是进一步计算得到。eval_metrics是一个字典,在bert中给出的句子分类的例子中,其为

eval_metrics 

评估模式的特点是,eval_metrics中的变量可以取出,并且直接打印出来。

(3)预测

elif 

对于预测模式,tf.estimator.EstimatorSpec的两个必须参数为mode, 以及predictions。mode在调用estimator时,自动按照调用方式赋值。predictions是一个字典,对于bert中的句子分类的例子,预测值只需要返回每一个句子的概率即可,predictions={"probabilities": probabilities}。然而,如果对于序列标注任务,则可以使 predictions = { "label_ids": label_ids, "predicts": predicts },之后,根据label_ids和predicts计算recall等指标。

第三部分:Estimator的建立,以及调用

前面对于Estimator的准备工作都已经完成了,即数据输入部分有了,模型建立部分有了,模型的三个功能也有了。接下来就是要实际的把模型跑起来。

(1)建立Estimator

run_config 

model_fn和run_config可以参考bert源码。可以看到,建立一个estimator需要指定gpu,tpu或是cpu使用,model_fn, run_config,以及训练,评估,预测的Batch_size。前三个应该为必须参数,batch_size在内部直接指定也可以,这些batch_size属于params里面的参数。

(2)Estimator的调用

if 

最后,有一个问题params["batch_size"]这个是在哪里传入的? 应该是estimator建立的时候,传入的params参数,但总感觉对不上。有待解决。。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值