Tenorflow -- estimator

目录:
1、构建Estimator和model_fn
2、使用 tf.estimator.TrainSpec 指定训练输入函数及相关参数
3、使用 tf.estimator.EvalSpec 指定验证输入函数及相关参数
4、使用 tf.estimator.train_and_evaluate 启动训练和验证过程
5、使用estimator.export_savedmodel导出模型

其中5可选

tf.estimator是tensorflow的high level api
在这里插入图片描述

1、构建Estimator(包含构建model_fn):

estimator = tf.estimator.Estimator(model_fn, model_dir=None, config=None,
                       params=None, warm_start_from=None)

其中
model_fn 是模型函数;
model_dir 是训练时模型保存的路径;
config 是 tf.estimator.RunConfig 的配置对象;
params 是传入 model_fn 的超参数字典;
warm_start_from 或者是一个预训练文件的路径,或者是一个 tf.estimator.WarmStartSettings 对象,用于完整的配置热启动参数。

model_fn需要自己构建,参数这里不传入,后面训练或验证时传入。
具体参数细节这个博客讲得很好:https://www.cnblogs.com/marsggbo/p/11232897.html

2.1 model_fn

输入和输出格式固定,不可更改

def my_model_fn(
   features, 	# This is batch_features from input_fn,`Tensor` or dict of `Tensor` (depends on data passed to `fit`).
   labels,     # This is batch_labels from input_fn
   mode,      # An instance of tf.estimator.ModeKeys
   params,  	# Additional configuration
   config=None
   ):
   # model具体定义
   # 返回tf.estimator.EstimatorSpec格式,不同模式参数不同,具体见下面
   return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.TRAIN,其它参数)
2.1.2 tf.estimator.EstimatorSpec

不同模式需要传入不同参数
根据mode的值的不同,需要不同的参数,即:

对于mode == ModeKeys.TRAIN:必填字段是loss和train_op.
对于mode == ModeKeys.EVAL:必填字段是loss.
对于mode == ModeKeys.PREDICT:必填字段是predictions.

def get_training_spec(learning_rate, joint_loss):
    """
    训练阶段的estimator构建
    """
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    gender_train_op = optimizer.minimize(
        loss=joint_loss,
        global_step=tf.train.get_global_step())
    # train必填mode、loss、train_op
    return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.TRAIN, loss=joint_loss, train_op=gender_train_op)

def get_eval_spec(gender_logits, age_logits, labels, loss):
    """
    评估阶段的estimator构建
    """
    eval_metric_ops = {
        "gender_accuracy": tf.metrics.accuracy(
            labels=labels['gender'], predictions=tf.argmax(gender_logits, axis=1)),
        'age_accuracy': tf.metrics.accuracy(labels=labels['age'], predictions=tf.argmax(age_logits, axis=1)),
        'age_precision': tf.metrics.sparse_precision_at_k(labels=labels['age'],
                                                          predictions=age_logits, k=10)
    }
    # eval必填mode、loss、eval_metric_ops
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=eval_metric_ops)

def get_prediction_spec(age_logits, logits):
    """
    构建预测阶段的estimator
    
    :param age_logits: age任务用于计算softmax的score向量
    :param logits: 性别任务用于计算softmax的score向量
    :return: Estimator spec 
    """
    predictions = {
        "classes": tf.argmax(input=logits, axis=1),
        "age_class": tf.argmax(input=age_logits, name='age_class', axis=1),
        "age_prob": tf.nn.softmax(age_logits, name='age_prob'),
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    # tf.estimator.EstimatorSpec这个实例是用来初始化Estimator类的
    # 测试必填mode和predictions,predictions为tensor或dict形式的tensor
    return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions)

2、使用 tf.estimator.TrainSpec 指定训练输入函数及相关参数

train_spec = tf.estimator.TrainSpec(input_fn, max_steps, hooks)

input_fn 用来提供训练时的输入数据;
max_steps 指定总共训练多少步;
hooks 是一个 tf.train.SessionRunHook 对象,用来配置分布式训练等参数。

3、使用 tf.estimator.EvalSpec 指定验证输入函数及相关参数

eval_spec = tf.estimator.EvalSpec(
			    input_fn,
			    steps=100,
			    name=None,
			    hooks=None,
			    exporters=None,
			    start_delay_secs=120,
			    throttle_secs=600)

input_fn 用来提供验证时的输入数据;
steps 指定总共验证多少步(一般设定为 None 即可);
hooks 用来配置分布式训练等参数;
exporters 是一个 Exporter 迭代器,会参与到每次的模型验证;start_delay_secs 指定多少秒之后开始模型验证;t
hrottle_secs 指定多少秒之后重新开始新一轮模型验证(当然,如果没有新的模型断点保存,则该数值秒之后不会进行模型验证,因此这是新一轮模型验证需要等待的最小秒数)。

4、使用 tf.estimator.train_and_evaluate 启动训练和验证过程

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

参数为前面三步内容

5、使用estimator.export_savedmodel导出模型

estimator.export_savedmodel(export_dir_base='{}/serving'.format(args.model_dir),
                                serving_input_receiver_fn=serving_fn,
                                as_text=True)
def serving_fn():
    receiver_tensor = {
        'image': tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3])
    }

    features = {
        'image': tf.image.resize_images(receiver_tensor['image'], [224, 224])
    }

    return tf.estimator.export.ServingInputReceiver(features, receiver_tensor)

参考:
https://www.jianshu.com/p/b8930fa13ea7
https://www.cnblogs.com/marsggbo/p/11232897.html

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值