tf.estimator.Estimator
Estimator class训练和测试TF模型。Estimator
对象封装好通过model_fn
指定的模型,给定输入和其它超参数,返回ops执行training, evaluation or prediction. 所有的输出(包含checkpoints, event files, etc.)被写入model_dir
。
属性
- config
传入
model_fn
,如果model_fn
有参数named “config” - model_dir
- model_fn
The model_fn with following signature:def model_fn(features, labels, mode, config)
- params
方法
__init__
__init__(
model_fn,
model_dir=None,
config=None,
params=None # 将要传入model_fn的超参数字典
)
evaluate
对训练模型评价
evaluate(
input_fn, # 输入函数,返回元组features和labels
steps=None,
hooks=None, # List of SessionRunHook subclass instances
checkpoint_path=None, # if none, 用model_dir中latest checkpoint
name=None
)
export_savemodel
导出inference graph作为一个SavedModel
export_savedmodel(
export_dir_base, # 目录
serving_input_receiver_fn, # 返回ServingInputReceiver的函数
assets_extra=None,
as_text=False,
checkpoint_path=None
)
-
get_variable_names
get_variable_names()
返回模型中所有变量名字的列表 -
get_variable_value(name)
根据变量name返回value -
latest_checkpoint()
在model_dir
中找到最近保存的checkpoint -
predict
根据给定的features产生预测
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
- train
给定训练数据后训练model
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)
转自:https://blog.csdn.net/MAJUN1259389904/article/details/79340547