estimator使用

一、model_fn

函数有5个输入参数features, labels, mode, params, config,并输出一个EstimatorSpec实例;

  • featuresinput_fn的第一个输出。
  • labelsinput_fn的第二个输出。
  • mode:操作类型(是训练、预测还是评估),对应tf.estimator.ModeKeys.EVAL/TRAIN/PREDICT
  • params:定义Estimator实例时传入的params属性。
  • config:定义Estimator实例时传入的config属性。
  • 输出EstimatorSpec实例介绍:
    • 训练时:需要指定losstrain_op
    • 预测时:需要指定predictions
    • 评估时:需要指定lossmetrics

二、实例化Estimator

  • config参数:
    • 可用于设置训练过程中相关操作,主要就是summary/save/logging操作。
    • 可用于设置 tf.Session 的配置,session_config 其实就是 tf.ConfigProto 对象。
    • 可用于设置多GPU训练,即train_distribute变量。
  • params参数:
    • 主要作用就是可以传入model_fn中,帮助实现各类功能。
    • 模型参数:以Faster R-CNN为例,可以选择backbone参数,anchors参数,weight decay参数等。
    • 训练参数:如优化器类型及参数、学习率参数。
    • 性能指标:如选择那些性能指标进行计算等。
  • model_fn
  • model_dir:summary和save的路径
  • configtf.estimator.RunConfig实例
  • params:输入参数,会传输到 model_fn 中。
  • warm_start_from:热启动功能,暂时没碰到做啥用的

三、tf.estimator.Estimator训练、预测、评估

def train(self,
        input_fn,
        hooks=None,
        steps=None,
        max_steps=None,
        saving_listeners=None):
# input_fn 在`1. 数据集`中介绍
# predict_keys 字符串列表,当EstimatorSpec.predictions是字典时使用
# hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务
# checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件
def predict(self,
          input_fn,
          predict_keys=None,
          hooks=None,
          checkpoint_path=None):
# input_fn 在`1. 数据集`中介绍
# steps 评估次数最大值
# hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务
# checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件
# name 名称,好像用于记录不同数据集上的结果,将评估结果保存到不同文件夹中
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
           name=None):

四、tf.data

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

参数说明:

batch:指更新梯度中使用的样本数;

repeat:将数据重复多次,主要用来处理epoch;

shuffle:打乱dataset中的元素;

map:Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的DataSet;

 

 

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值