一、model_fn
函数有5个输入参数features, labels, mode, params, config,并输出一个EstimatorSpec
实例;
features
:input_fn
的第一个输出。labels
:input_fn
的第二个输出。mode
:操作类型(是训练、预测还是评估),对应tf.estimator.ModeKeys.EVAL/TRAIN/PREDICT
。params
:定义Estimator
实例时传入的params
属性。config
:定义Estimator
实例时传入的config
属性。- 输出
EstimatorSpec
实例介绍:- 训练时:需要指定
loss
和train_op
。 - 预测时:需要指定
predictions
。 - 评估时:需要指定
loss
和metrics
- 训练时:需要指定
二、实例化Estimator
config
参数:- 可用于设置训练过程中相关操作,主要就是summary/save/logging操作。
- 可用于设置
tf.Session
的配置,session_config
其实就是tf.ConfigProto
对象。 - 可用于设置多GPU训练,即
train_distribute
变量。
params
参数:- 主要作用就是可以传入
model_fn
中,帮助实现各类功能。 - 模型参数:以Faster R-CNN为例,可以选择backbone参数,anchors参数,weight decay参数等。
- 训练参数:如优化器类型及参数、学习率参数。
- 性能指标:如选择那些性能指标进行计算等。
- 主要作用就是可以传入
model_fn
model_dir
:summary和save的路径config
:tf.estimator.RunConfig
实例params
:输入参数,会传输到model_fn
中。warm_start_from
:热启动功能,暂时没碰到做啥用的
三、tf.estimator.Estimator
训练、预测、评估
def train(self,
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None):
# input_fn 在`1. 数据集`中介绍
# predict_keys 字符串列表,当EstimatorSpec.predictions是字典时使用
# hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务
# checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件
def predict(self,
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None):
# input_fn 在`1. 数据集`中介绍
# steps 评估次数最大值
# hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务
# checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件
# name 名称,好像用于记录不同数据集上的结果,将评估结果保存到不同文件夹中
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
name=None):
四、tf.data
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
参数说明:
batch:指更新梯度中使用的样本数;
repeat:将数据重复多次,主要用来处理epoch;
shuffle:打乱dataset中的元素;
map:Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的DataSet;