在TensorFlow 中使用hooks实现Early_Stopping

在 这篇博客 中训练CNN的时候,即便是对fc层加了dropout,对loss加了L2正则化,依然出现了过拟合的情况(如下图所示),于是开始尝试用early stop解决拟合问题。

(训练集的loss在下降而测试集的loss却在5k步左右开始上升,说明过拟合了) 

想要实现ES,首先需要知道loss的值(以便根据loss值在xx次迭代内的变化决定是否需要停止training),Tensorflwo中提供了hook来帮助我们从graph中“钩取”出想要的值。

这里用到的hook是 tf.contrib.estimator.stop_if_no_decrease_hook ,它的源码可以 看这里 

tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name,
    max_steps_without_decrease,
    eval_dir=None,
    min_steps=0,
    run_every_secs=60,
    run_every_steps=None
)

其中metric_name用来指明监控的变量(比如loss或者accuracy),在这篇博文中的run_mnist()代码如下:

def run_mnist(params):
    model_helpers.apply_clean(params)  # 清空model_dir文件夹下的旧文件

        #实例化estimator
    paramsdic = params.flag_values_dict()
    model = tf.estimator.Estimator(model_fn=cnn_model_fn,model_dir=params.model_dir,params=paramsdic) #Estimator的构造函数会把params传给model_fn
    #为啥不能params=params??因为传入的params是一个类!!!absl.flags._flagvalues.FlagValues类,需要调用函数flag_values_dict()将他的属性转化成dic才能被传入model_fn
    #没转化成dic时用params.dropout_rate   代表取出属性
    #转化成dic后用params['dropout_rate']  #代表取出key对应的value

        #实例化hooks(用于监控台输出程序运行的记录日志,记录哪些量由tensor_to_log字典给出)而tensorboard的图似乎和hook没关系?
    tensor_to_log={'prob':'softmax_tensor'}#打印prob,其值来源于softmax_tensor
    train_hooks = hooks_helper.get_train_hooks(name_list=params.hooks,model_dir=params.model_dir,)#tensors_to_log=tensor_to_log)

    os.makedirs(model.eval_dir())
    train_hoooks_for_earlyStoping = stop_if_no_decrease_hook(model,eval_dir=model.eval_dir(),metric_name='accuracy',max_steps_without_decrease=1000,min_steps=100)
    #必须使用loss而不是eval_loss,因为train里自动记录的是名字为‘loss’的值
        #input_fn函数
    def train_input_fn():#这里虽然返回的是一个ds但是实际上这个是被zip(feature,label)的ds,可以直接被parse成feature,label [也就是 model.train中需要input_fn返回的形式]
        ds = dataset.train(params.data_dir)
        ds = ds.cache().shuffle(buffer_size=50000).batch(params.batch_size)
        ds = ds.repeat(params.epochs_between_evals)
        return ds
    def eval_input_fn():
        return dataset.test(params.data_dir).batch(
            params.batch_size).make_one_shot_iterator().get_next()
        #每次返回一个(fea,lab)对??
        #为啥eval的input返回的是迭代器而train的input返回的是整个的dataset??

        #train和eval
    for i in range(params.train_epochs // params.epochs_between_evals):
        # tf.estimator.train_and_evaluate(model,train_spec=tf.estimator.TrainSpec(train_input_fn,hooks=[train_hoooks_for_earlyStoping]),
        #                                 eval_spec=tf.estimator.EvalSpec(eval_input_fn))
        model.train(input_fn=train_input_fn,hooks=[train_hoooks_for_earlyStoping])# 如果这里参数传入了 hooks=train_hooks 那么model_fn中的train就要把注释的几个identity解开
        if train_hoooks_for_earlyStoping.stopFlag == True :
            break
        eval_results = model.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        if model_helpers.past_stop_threshold(params.stop_threshold,
                                             eval_results['accuracy']):
            break

其中 定义ES的hook语句如下所示,他返回的是一个实例化的hook:

train_hoooks_for_earlyStoping = stop_if_no_decrease_hook(model,eval_dir=model.eval_dir(),metric_name='accuracy',max_steps_without_decrease=1000,min_steps=100)

由于只是将这个hook传给了model.train(),所以只能让train停止,此后会接着执行model.eval()以及进入下一次循环,因此并没有真正起到EarlyStopping的作用(整个程序停止,这里只是让model.train()停止),所以需要对stop_if_no_decrease_hook的源码进行修改,为这个类增加一个属性,用来标识是否开始ES:

修改的源码如下所示:

class _StopOnPredicateHook(session_run_hook.SessionRunHook):
  """Hook that requests stop when `should_stop_fn` returns `True`."""

  def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None):
    if not callable(should_stop_fn):
      raise TypeError('`should_stop_fn` must be callable.')
    
#增加这个tag !!!!!
    self.stopFlag = False 
#增加这个tag !!!!!
    self._should_stop_fn = should_stop_fn
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=run_every_secs, every_steps=run_every_steps)
    self._global_step_tensor = None
    self._stop_var = None
    self._stop_op = None

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    self._stop_var = _get_or_create_stop_var()
    self._stop_op = state_ops.assign(self._stop_var, True)

  def before_run(self, run_context):
    del run_context
    return session_run_hook.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results
    if self._timer.should_trigger_for_step(global_step):
      self._timer.update_last_triggered_step(global_step)
      if self._should_stop_fn():
        self.stopFlag = True
        tf_logging.info('Requesting early stopping at global step %d',
                        global_step)
        run_context.session.run(self._stop_op)
        run_context.request_stop()


class _CheckForStoppingHook(session_run_hook.SessionRunHook):
  """Hook that requests stop if stop is requested by `_StopOnPredicateHook`."""

  def __init__(self):
    self._stop_var = None
#增加这个tag !!!!!
    self.stopFlag = False 
#增加这个tag !!!!!

  def begin(self):
    self._stop_var = _get_or_create_stop_var()

  def before_run(self, run_context):
    del run_context
    return session_run_hook.SessionRunArgs(self._stop_var)

  def after_run(self, run_context, run_values):
    should_early_stop = run_values.results
    if should_early_stop:
      self.stopFlag = True
      tf_logging.info('Early stopping requested, suspending run.')
      run_context.request_stop()

这样的话在loop的循环中:

    for i in range(params.train_epochs // params.epochs_between_evals):
        # tf.estimator.train_and_evaluate(model,train_spec=tf.estimator.TrainSpec(train_input_fn,hooks=[train_hoooks_for_earlyStoping]),
        #                                 eval_spec=tf.estimator.EvalSpec(eval_input_fn))
        model.train(input_fn=train_input_fn,hooks=[train_hoooks_for_earlyStoping])# 如果这里参数传入了 hooks=train_hooks 那么model_fn中的train就要把注释的几个identity解开
        if train_hoooks_for_earlyStoping.stopFlag == True :
            break
        eval_results = model.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        if model_helpers.past_stop_threshold(params.stop_threshold,
                                             eval_results['accuracy']):
            break

model.train()之后立刻验证新增加的tag的值,如果是true,说明上面的model.train()不是正常结束的,而是由于ES结束的,此时立刻跳出循环,结束整个训练和测试loop。

最后得到的结果如下:

 

可以看到大约4k步之后,hook检测到eval的loss很久没有降低了,于是进行了ES(没有ES时会迭代25k步左右,见上图)

参考链接:

Implement early stopping in tf.estimator.DNNRegressor using the available training hooks

Early stopping with tf.estimator, how?

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值