TensorFlow高级API系列(二):从源码看如何自定义estimator

本文深入探讨TensorFlow的高级API Estimator,重点解析如何从源码角度理解并自定义model_fn,包括model_fn的主要参数、mode的使用以及config配置。通过实例介绍Estimator的保存与恢复策略,如save_summary_steps、save_checkpoints_steps等。
摘要由CSDN通过智能技术生成

源码解析

上一篇博客,实例化estimator的代码如下:

tf.estimator.Estimator(
    model_fn=model_fn,  # First-class function
    params=params,  # HParams
    config=run_config  # RunConfig
)

我们从这个实例化进入,看我们需要传给estimator的参数都是些什么?上面三个代码不是全部参数,看源码

class Estimator(object):
  def __init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None):
  ...

可以看到需要传入的参数如下:

model_dir: 指定checkpoints和其他日志存放的路径。
model_fn: 这个是需要我们自定义的网络模型函数,后面详细介绍
config: 用于控制内部和checkpoints等,如果model_fn函数也定义config这个变量,则会将config传给model_fn
params: 该参数的值会传递给model_fn。
warm_start_from: 指定checkpoint路径,会导入该checkpoint开始训练

其中最重要的就是model_fn,params是服务于这个模型。

model_fn
这个接受的是我们自定义的模型,下面看一下它的主要参数。
模型的定义一般如下:

def my_model_fn(
   features,    # This is batch_features from input_fn,`Tensor` or dict of `Tensor` (depends on data passed to `fit`).
   labels,     # This is batch_labels from input_fn
   mode,      # An instance of tf.estimator.ModeKeys
   params,      # Additional configuration
   config=None
   ):
  • 前两个参数是从输入函数中返回的特征和标签批次;也就是说,features 和 labels 是模型将使用的数据。
  • params
    是一个字典,它可以传入许多参数用来构建网络或者定义训练方式等。例如通过设置params[‘n_classes’]来定义最终输出节点的个数等。
  • config 通常用来控制checkpoint或者分布式什么,这里不深入研究。
  • mode 参数表示调用程序是请求训练、评估还是预测,分别通过tf.estimator.ModeKeys.TRAIN / EVAL /
    PREDICT
    来定义。另外通过观察DNNClassifier的源代码可以看到,mode这个参数并不用手动传入,因为Estimator会自动调整。例如当你调用estimator.train(…)的时候,mode则会被赋值tf.estimator.ModeKeys.TRAIN。
    model_fn需要对于不同的模式提供不同的处理方式,并且都需要返回一个tf.estimator.EstimatorSpec的实例。

如下实例:

def my_model_fn(features,labels,mode,params):
    #输入层,feature_columns对应Classifier(feature_columns=...)
    net 
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值