TensorFlow 版本:1.10.0 > Guide > Introduction to Estimators
Estimator 概述
本篇将介绍 TensorFlow 中的 Estimators(可极大简化机器学习编程过程)。Estimators 中封装了以下几部分:
- 训练(training)
- 评估(evaluation)
- 预测(prediction)
- 输出模型(export for serving)
我们既可以使用内置 Estimator,也可以编写自定义 Estimator。
注意:TensorFlow 中的
tf.contrib.learn.Estimator
已经弃用了,请不要使用该 API。
文章目录
1. Estimator 的优势
Estimator 有以下优势:
- 对分布式的良好支持(不需要更改代码)。
- 有利于模型开发者之间的代码分享。
- 简化了模型的创建工作。
- Estimator 建立在 `tf.layers` 上,这简化了自定义 Estimator 的编写。
- Estimator 会为你创建 graph。
Estimator 提供了一个安全的分布式训练环境,其会帮我们控制这么、何时去:
- 建立 graph。
- 初始化 variables。
- 开始 queues。
- 处理 exceptions。
- 创建 checkpoint 文件,从失败中恢复训练。
- 保存 summaries for TensorBoard。
当用 Estimator 编写一个 application,你必须将 input pipeline 和 model 分开。这种分离简化了在不同数据集上的 experiments。
2. 内置的 Estimator
内置的 Estimator 使得你可以在更高层面思考问题。内置的 Estimator 会为你创建、管理 Graph 和 Session 对象。另外,内置的 Estimator 使你可以在最小的代码修改量的情况下实验不同的模型结构。
2.1 基于内置 Estimator 的程序的结构
使用内置 Estimator 的 TF 程序一般包含以下四步:
编写一个或多个数据集导入函数。 例如:创建一个函数来导入训练数据集,另一个函数来导入测试数据集。每一个数据集导入函数必须返回两个对象:
- 一个字典。字典的键名为特征的名字,键值为 表示特征数据的 Tensor 或 Sparse Tensor。
- 一个 Tensor。该 Tensor 包含一个或多个 label。
例如,下面的代码说明了 input function 的基本框架:
def input_fn(dataset): ... # m