源码解析
上一篇博客,实例化estimator的代码如下:
tf.estimator.Estimator(
model_fn=model_fn, # First-class function
params=params, # HParams
config=run_config # RunConfig
)
我们从这个实例化进入,看我们需要传给estimator的参数都是些什么?上面三个代码不是全部参数,看源码
class Estimator(object):
def __init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None):
...
可以看到需要传入的参数如下:
model_dir: 指定checkpoints和其他日志存放的路径。
model_fn: 这个是需要我们自定义的网络模型函数,后面详细介绍
config: 用于控制内部和checkpoints等,如果model_fn函数也定义config这个变量,则会将config传给model_fn
params: 该参数的值会传递给model_fn。
warm_start_from: 指定checkpoint路径,会导入该checkpoint开始训练
其中最重要的就是model_fn,params是服务于这个模型。
model_fn
这个接受的是我们自定义的模型,下面看一下它的主要参数。
模型的定义一般如下:
def my_model_fn(
features, # This is batch_features from input_fn,`Tensor` or dict of `Tensor` (depends on data passed to `fit`).
labels, # This is batch_labels from input_fn
mode, # An instance of tf.estimator.ModeKeys
params, # Additional configuration
config=None
):
- 前两个参数是从输入函数中返回的特征和标签批次;也就是说,features 和 labels 是模型将使用的数据。
- params
是一个字典,它可以传入许多参数用来构建网络或者定义训练方式等。例如通过设置params[‘n_classes’]来定义最终输出节点的个数等。 - config 通常用来控制checkpoint或者分布式什么,这里不深入研究。
- mode 参数表示调用程序是请求训练、评估还是预测,分别通过tf.estimator.ModeKeys.TRAIN / EVAL /
PREDICT
来定义。另外通过观察DNNClassifier的源代码可以看到,mode这个参数并不用手动传入,因为Estimator会自动调整。例如当你调用estimator.train(…)的时候,mode则会被赋值tf.estimator.ModeKeys.TRAIN。
model_fn需要对于不同的模式提供不同的处理方式,并且都需要返回一个tf.estimator.EstimatorSpec的实例。
如下实例:
def my_model_fn(features,labels,mode,params):
#输入层,feature_columns对应Classifier(feature_columns=...)
net