详细教程:https://www.tensorflow.org/get_started/custom_estimators
预创建的 Estimator 是 tf.estimator.Estimator 基类的子类,而自定义 Estimator 是 tf.estimator.Estimator 的实例:
模型函数(即 model_fn)会实现机器学习算法。采用预创建的 Estimator 和自定义 Estimator 的唯一区别是:
- 如果采用预创建的 Estimator,则有人已为您编写了模型函数。
- 如果采用自定义 Estimator,则您必须自行编写模型函数。
步骤:
- 编写输入函数
- 创建特征列
- 编写模型函数
- 定义模型
- 实现训练、评估和预测
#1. 编写输入函数
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# Return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()
此输入函数会构建可以生成批次 (features, labels) 对的输入管道,其中 features 是字典特征。
#2. 创建特征列
定义模型的特征列来指定模型应该如何使用每个特征。
以下代码为每个输入特征创建一个简单的 numeric_column,表示应该将输入特征的值直接用作模型的输入:
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
#3. 编写模型函数
我们要使用的模型函数具有以下调用签名:
def my_model_fn(
features, # This is batch_features from input_fn
labels, # This is batch_labels from