tensorflow eatimator实现early-stopping

相信大家,为了避免过拟合,经常需要用到early-stopping,即在你的loss接近收敛的时候,就可以提前停止训练了。

预备知识

tensorflow estimator详细介绍,实现模型的高效训练

tensorflow通过tfrecord高效读写数据

API介绍

tf.estimator.experimental.stop_if_no_increase_hook(
    estimator, metric_name, max_steps_without_increase, eval_dir=None, min_steps=0,
    run_every_secs=60, run_every_steps=None
)
Args
estimatorA tf.estimator.Estimator instance.
metric_namestr, metric to track. “loss”, “accuracy”, etc.
max_steps_without_increaseint, maximum number of training steps with no increase in the given metric.
eval_dirIf set, directory containing summary files with eval metrics. By default, estimator.eval_dir() will be used.
min_stepsint, stop is never requested if global step is less than this value. Defaults to 0.
run_every_secsIf specified, calls should_stop_fn at an interval of run_every_secs seconds. Defaults to 60 seconds. Either this or run_every_steps must be set.
run_every_stepsIf specified, calls should_stop_fn every run_every_steps steps. Either this or run_every_secs must be set.
  1. estimator:定义你的模型结构,以及训练(train)、验证(evaluate)、预测(predict)过程
  2. metric_name:评判是否要early-stopping的度量
  3. max_steps_without_increase:当评判的度量metric如loss,最多 多少步不下降就early-stopping
  4. min_steps:至少训练多少步,才开始考虑early-stopping
  5. run_every_steps:每n步进行一次early-stopping的评估

说明

  1. 首先,需要明确一点:early-stopping是在验证(evaluate)过程中进行的,所以只能用tf.estimator.train_and_evaluate,并且度量metric_name是针对于验证集eval的,不是训练集;
  2. 整个过程是这样:训练train --> 保存模型 --> 验证evaluate --> 判断是否要early-stopping --> 训练train
  3. 所以,evaluate和early stop的频率实际上是由你模型保存的频率决定的
  4. max_steps_without_increase和run_every_steps的步数是在验证(evaluate)时才计算的,即run_every_steps是指每几次eval就进行early stop的判定,max_steps_without_increase是指evaluate多少次,loss不下降就early-stopping
  5. evaluate过程还可以定义accuracy这样的度量,这种是需要提高的,所以就有对应的tf.estimator.experimental.stop_if_no_decrease_hook

代码

import tensorflow as tf

from estimator import model_fn, input_fn_bulider

# 设置训练多少步就进行模型的保存
runConfig = tf.estimator.RunConfig(save_checkpoints_steps=10)

estimator = tf.estimator.Estimator(model_fn,
                                   model_dir='your_save_path',
                                   config=runConfig,
                                   params={'lr': 0.01})

# 在这里定义一个early-stopping
# 在eval过程执行early-stopping判断,所以评判标准也是eval数据集的metric_name
# max_steps_without_decrease:loss最多多少次不降低就停止。进行一次eval相当于一步。
early_stop = tf.estimator.experimental.stop_if_no_decrease_hook(estimator,
                                                                metric_name='loss',
                                                                max_steps_without_decrease=1,
                                                                run_every_steps=1,
                                                                run_every_secs=None)

logging_hook = tf.train.LoggingTensorHook(every_n_iter=1,
                                          tensors={'loss': 'loss:0'})

# 定义训练(train)过程的数据输入方式
train_input_fn = input_fn_bulider('train.tfrecord', batch_size=1, is_training=True)
# 定义验证(eval)过程的数据输入方式
eval_input_fn = input_fn_bulider('eval.tfrecord', batch_size=1, is_training=False)

# 创建一个TrainSpec实例
train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=100,
                                    hooks=[logging_hook, early_stop])
# 创建一个EvalSpec实例
eval_spec = tf.estimator.EvalSpec(eval_input_fn)

# 流程:训练train --> 保存模型 --> 验证eval --> 判断是否要early-stopping --> 训练train
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值