https://www.cnblogs.com/zongfa/p/10149483.html
Estimator是一个可极大地简化机器学习编程的高阶API,它封装了下列操作:
训练, 评估,预测, 导出以供使用
Estimator本身在tf.layers之上构建而成,可以简化自定义过程
使用Estimator编写应用时,必须将数据输入管道从模型中分离出来,这种分离简化了不同数据集的试验流程
可以将现有的keras模型转换成Estimator,这样做之后,keras模型就可以利用Estimator的优势,例如分布式训练,调用
tf.keras.estimator.model_to_estimator
,如下例所示
# Instantiate a Keras inception v3 model.
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
loss='categorical_crossentropy',
metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)
# Treat the derived Estimator as you would with any other Estimator.
# First, recover the input name(s) of Keras model, so we can use them as the
# feature column name(s) of the Estimator input function:
keras_inception_v3.input_names # print out: ['input_1']
# Once we have the input name(s), we can create the input function, for example,
# for input(s) in the format of numpy ndarray:
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"input_1": train_data},
y=train_labels,
num_epochs=1,
shuffle=False)
# To train, we call Estimator's train function:
est_inception_v3.train(input_fn=train_input_fn, steps=2000)