To avoid overfitting, you will often need early stopping: once the loss is close to converging, you can stop training ahead of schedule.
Prerequisites
A detailed introduction to the TensorFlow Estimator for efficient model training
API overview
tf.estimator.experimental.stop_if_no_increase_hook(
estimator, metric_name, max_steps_without_increase, eval_dir=None, min_steps=0,
run_every_secs=60, run_every_steps=None
)
Args | |
---|---|
estimator | A tf.estimator.Estimator instance. |
metric_name | str , metric to track. “loss”, “accuracy”, etc. |
max_steps_without_increase | int , maximum number of training steps with no increase in the given metric. |
eval_dir | If set, directory containing summary files with eval metrics. By default, estimator.eval_dir() will be used. |
min_steps | int , stop is never requested if global step is less than this value. Defaults to 0. |
run_every_secs | If specified, calls should_stop_fn at an interval of run_every_secs seconds. Defaults to 60 seconds. Either this or run_every_steps must be set. |
run_every_steps | If specified, calls should_stop_fn every run_every_steps steps. Either this or run_every_secs must be set. |
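- estimator: defines your model structure along with the training (train), evaluation (evaluate), and prediction (predict) procedures
- metric_name: the metric used to decide whether to stop early
- max_steps_without_increase: the maximum number of steps the metric may go without increasing before early stopping is triggered. For a metric that should fall, such as loss, the counterpart is max_steps_without_decrease on stop_if_no_decrease_hook
- min_steps: the minimum number of training steps before early stopping is considered
- run_every_steps: run the early-stopping check every n steps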
Notes
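- First, one point must be clear: the early-stopping decision is based on evaluation (evaluate) results, so the hook has to be used with tf.estimator.train_and_evaluate, and metric_name refers to a metric on the validation (eval) set, not the training set.
- The whole cycle is: train --> save checkpoint --> evaluate --> decide whether to stop early --> train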
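- So how often evaluation and the early-stopping decision happen is effectively determined by how often checkpoints are saved (throttle_secs in EvalSpec, 600 seconds by default, can further limit how often evaluation runs)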
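- Per the argument table above, max_steps_without_increase and run_every_steps are both counted in training (global) steps. Metric values only exist at the steps where an evaluation ran, so the comparison is effectively made between evaluations: training stops once an evaluation lands at least max_steps_without_increase steps after the best evaluation seen so far. With a small value such as 1, a single evaluation that fails to improve is enough to stop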
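- Evaluation can also report metrics such as accuracy, which should rise rather than fall. That case is what tf.estimator.experimental.stop_if_no_increase_hook, shown above, is for. For a metric that should fall, such as loss, use its counterpart tf.estimator.experimental.stop_if_no_decrease_hook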
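To make the stopping rule concrete, here is a minimal pure-Python sketch of the decision as described in the notes above. It is a simplification for illustration, not the library's code: the function name `should_stop`, its parameters, and the loss values in `eval_loss` are all hypothetical, and the sketch assumes one evaluation every 10 training steps.

```python
from typing import Dict, Optional


def should_stop(eval_results: Dict[int, float],
                max_steps_without_improvement: int,
                higher_is_better: bool,
                min_steps: int = 0) -> Optional[int]:
    """Return the global step at which stopping would be requested, or None.

    eval_results maps the global step of each evaluation to its metric value.
    (Hypothetical helper that mirrors the rule described in the notes.)
    """
    best_val = None
    best_step = None
    for step in sorted(eval_results):
        if step < min_steps:
            continue  # evaluations before min_steps are ignored
        val = eval_results[step]
        if best_val is None:
            improved = True
        elif higher_is_better:
            improved = val > best_val
        else:
            improved = val < best_val
        if improved:
            best_val, best_step = val, step
        # Stop when this evaluation is too far past the best one so far.
        if step - best_step >= max_steps_without_improvement:
            return step
    return None


# Made-up eval losses, one evaluation every 10 training steps.
eval_loss = {10: 0.90, 20: 0.70, 30: 0.72, 40: 0.71, 50: 0.69}

print(should_stop(eval_loss, 1, higher_is_better=False))   # 30
print(should_stop(eval_loss, 20, higher_is_better=False))  # 40
print(should_stop(eval_loss, 30, higher_is_better=False))  # None
```

With a threshold of 1, the first evaluation that fails to beat the best loss (step 30) requests the stop. With a threshold of 30, the improvement at step 50 arrives in time and training continues.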
Code
import tensorflow as tf
from estimator import model_fn, input_fn_bulider
# Save a checkpoint every 10 training steps
runConfig = tf.estimator.RunConfig(save_checkpoints_steps=10)
estimator = tf.estimator.Estimator(model_fn,
                                   model_dir='your_save_path',
                                   config=runConfig,
                                   params={'lr': 0.01})
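# Define the early-stopping hook here
# The decision uses evaluation results, so metric_name refers to a metric on the eval dataset
# max_steps_without_decrease: stop once the eval loss has gone this many global steps
# without decreasing. With a value of 1, a single eval that fails to improve triggers the stop.
early_stop = tf.estimator.experimental.stop_if_no_decrease_hook(estimator,
                                                                metric_name='loss',
                                                                max_steps_without_decrease=1,
                                                                run_every_steps=1,
                                                                run_every_secs=None)
# Log the training loss every iteration; 'loss:0' assumes model_fn names its loss tensor 'loss'
logging_hook = tf.train.LoggingTensorHook(every_n_iter=1,
                                          tensors={'loss': 'loss:0'})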
# Input function for the training (train) phase
train_input_fn = input_fn_bulider('train.tfrecord', batch_size=1, is_training=True)
# Input function for the evaluation (eval) phase
eval_input_fn = input_fn_bulider('eval.tfrecord', batch_size=1, is_training=False)
# Create a TrainSpec instance
train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=100,
                                    hooks=[logging_hook, early_stop])
# Create an EvalSpec instance
eval_spec = tf.estimator.EvalSpec(eval_input_fn)
# Flow: train --> save checkpoint --> eval --> decide whether to stop early --> train
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)