1、概述
TF1.3开始引入estimator,通过框架图可以看到Estimator是属于High level的API,而Mid-level API分别是:
- Layers:用来构建网络结构
- Datasets: 用来构建数据读取pipeline
- Metrics:用来评估网络性能
Estimator简化和抽象了管理训练、评估和预测,跟Keras类似,Estimator是模型级别的API,我们不再需要关注Session这样的操作,只需要几步简单的定义就能让模型run起来。
2、使用方法
主要步骤:
(1)实例化一个Estimator,并与model_fn绑定。
(2)定义model_fn函数,包含模型结构、优化方法等,返回。
(3)定义定义input_fn函数,作为参数传递给Estimator的train函数,最后会被model_fn消费。
下面是一个基本的模式:
(1)构建estimator
import tensorflow as tf
# 创建一个名为classifier的estimator
classifier = tf.estimator.Estimator(model_fn, model_dir=None,
config=None, params=None, warm_start_from=None)
- model_fn是模型函数,estimator做的事情就是把获取到的参数传给model_fn。
- model_dir是存储和加载检查点和事件文件的路径,如果这里不设置就会从config里读取,如果都不设置则保存到临时目录,如果都设置则必须一致。
- config是estimator.RunConfig对象,主要是分布式训练和模型存储相关配置。
- params是传入model_fn的参数。
- warm_start_from是预训练模型路径。