TensorFlow --- monitoring, saver

This post covers training monitoring and model saving in TensorFlow. tf.train.Saver makes it easy to save and restore variables. It then walks through tf.train.Supervisor and tf.train.MonitoredTrainingSession, including their automatic checkpointing policies, how summary ops are run, and the use of training hooks. Both approaches give you flexibility and control over the training process: they can recover a model from a checkpoint and let you customize save and monitoring frequency.

This post mainly introduces two ways to monitor training.

Before that, a word about tf.train.Saver, which is used to save variables. Usage is shown below; you can save all variables or only a specified subset.

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# If you want to save some variables not all of them
# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")


Restore variables like this:

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
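
If you only know the checkpoint directory rather than the exact file, tf.train.latest_checkpoint can look it up. A minimal sketch; the directory path here is an illustrative assumption, not from the examples above:

# Find the most recent checkpoint file in a directory (illustrative path).
ckpt_path = tf.train.latest_checkpoint("/tmp/train_dir")

saver = tf.train.Saver()
with tf.Session() as sess:
  if ckpt_path:
    # restore() initializes the restored variables from the checkpoint,
    # so no separate init op is needed for them.
    saver.restore(sess, ckpt_path)
  else:
    sess.run(tf.global_variables_initializer())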


1. tf.train.Supervisor

  ...create graph...
  my_train_op = ...

  sv = tf.train.Supervisor(logdir="/my/training/directory")
  with sv.managed_session() as sess:
    for step in range(100000):
      if sv.should_stop():
        break
      sess.run(my_train_op)

tf.train.Supervisor is the simpler approach. In the example above, it saves the graph's variables into 'logdir' every 10 minutes (the default), automatically runs all summary ops every 2 minutes (the default), and writes the event file into 'logdir'. It also keeps track of the global step and starts tf.train.QueueRunner threads in the background. You can use tf.train.Supervisor.should_stop() to check for a stop signal: define your own stop criterion, call tf.train.Supervisor.request_stop() when it is met, and the next check of should_stop() will return True and stop the loop.
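
For example, here is a minimal sketch of wiring in a custom stop criterion; the my_loss tensor and the 0.01 threshold are illustrative assumptions:

  ...create graph, defining my_train_op and a my_loss tensor...

  sv = tf.train.Supervisor(logdir="/my/training/directory")
  with sv.managed_session() as sess:
    while not sv.should_stop():
      _, loss_value = sess.run([my_train_op, my_loss])
      if loss_value < 0.01:   # illustrative stop criterion
        sv.request_stop()     # makes the next should_stop() return True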


If you do not want it to save summaries automatically (saving too frequently can blow up memory), maintain the summaries yourself as follows:

  ...create graph...
  my_train_op = ...
  my_summary_op = tf.summary.merge_all()

  sv = tf.train.Supervisor(logdir="/my/training/directory",
                           summary_op=None)  # Do not run the summary service
  with sv.managed_session() as sess:
    for step in range(100000):
      if sv.should_stop():
        break
      if step % 100 == 0:
        _, summ = sess.run([my_train_op, my_summary_op])
        sv.summary_computed(sess, summ)
      else:
        sess.run(my_train_op)


You can of course also restore a model from a checkpoint file, initializing your model's variables with the parameters of a pre-trained model. Set the 'init_fn' argument of tf.train.Supervisor() as follows:

  ...create graph...
  # Create a saver that restores only the pre-trained variables.
  pre_train_saver = tf.train.Saver([pre_train_var1, pre_train_var2])

  # Define an init function that loads the pretrained checkpoint.
  def load_pretrain(sess):
    pre_train_saver.restore(sess, "<path to pre-trained-checkpoint>")

  # Pass the init function to the supervisor.
  #
  # The init function is called _after_ the variables have been initialized
  # by running the init_op.
  sv = tf.train.Supervisor(logdir="/my/training/directory",
                           init_fn=load_pretrain)
  with sv.managed_session() as sess:
    # Here sess was either initialized from the pre-trained-checkpoint or
    # recovered from a checkpoint saved in a previous run of this code.
    ...

You can also set checkpoint_basename / save_model_secs / saver / save_summaries_secs / summary_op and so on. For example:

  ...create graph...
  my_saver = tf.train.Saver(<only some variables>)
  sv = tf.train.Supervisor(logdir="/my/training/directory",
                           saver=my_saver,
                           save_model_secs=30)
  with sv.managed_session() as sess:
    ...training loop...

You can also run a service of your own:

def my_additional_summaries(sv, sess):
  summaries = sess.run(my_additional_summary_op)
  sv.summary_computed(sess, summaries)

...
  sv = tf.train.Supervisor(logdir="/my/training/directory")
  with sv.managed_session() as sess:
    # Call my_additional_summaries() every 1200s, i.e. every 20 minutes,
    # passing (sv, sess) as arguments.
    sv.loop(1200, my_additional_summaries, args=(sv, sess))
    ...main training loop...
For more details, see the official tf.train.Supervisor documentation.


2. tf.train.MonitoredTrainingSession()

The following approach is taken from CIFAR10. It uses training hooks. Honestly, I am not quite sure what a hook is, having never come across one, but we can still imitate the usage. The signature of tf.train.MonitoredTrainingSession() really is a bit long...

tf.train.MonitoredTrainingSession(master='', is_chief=True, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=100, config=None)

checkpoint_dir: the directory where checkpoints are written

hooks: 'a list of SessionRunHook objects'. As I understand it, a hook implements some piece of per-step functionality, and hook class definitions all seem to follow a fixed structure; an example is given later

save_checkpoint_secs / save_summaries_steps: easy to understand from the names

config: 'an instance of tf.ConfigProto'. Used to configure the session; see tf.ConfigProto for details, e.g. to inspect CPU/GPU placement. A small sketch follows.
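
For instance, a minimal sketch of a common session configuration; these particular options are my own illustration, not from the original example:

config = tf.ConfigProto(
    log_device_placement=True,   # log which device each op is placed on
    allow_soft_placement=True,   # fall back to another device if placement fails
    gpu_options=tf.GPUOptions(allow_growth=True))  # claim GPU memory gradually

with tf.train.MonitoredTrainingSession(checkpoint_dir="/my/training/directory",
                                       config=config) as sess:
  ...training loop...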


Here is a usage example:

   with tf.train.MonitoredTrainingSession(  
        checkpoint_dir=FLAGS.train_dir,  
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),  
               tf.train.NanTensorHook(loss),  
               _LoggerHook()],  
        config=tf.ConfigProto(  
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:  
      while not mon_sess.should_stop():  
        mon_sess.run(train_op) 

This looks a bit more compact to use than Supervisor. TensorFlow ships with many predefined hooks that can be used directly; see the documentation for their descriptions and the source for the corresponding code. You can also build your own hook class, which should define the three methods begin(), before_run(), and after_run():

    # Assumes `import time`, `from datetime import datetime`, and that
    # FLAGS and the `loss` tensor are defined elsewhere (as in the CIFAR10 code).
    class _LoggerHook(tf.train.SessionRunHook):  
      """Logs loss and runtime."""  
  
      def begin(self):  
        self._step = -1  
        self._start_time = time.time()  
  
      def before_run(self, run_context):  
        self._step += 1  
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.  
  
      def after_run(self, run_context, run_values):  
        if self._step % FLAGS.log_frequency == 0:  
          current_time = time.time()  
          duration = current_time - self._start_time  
          self._start_time = current_time  
  
          loss_value = run_values.results  
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration  
          sec_per_batch = float(duration / FLAGS.log_frequency)  
  
          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '  
                        'sec/batch)')  
          print (format_str % (datetime.now(), self._step, loss_value,  
                               examples_per_sec, sec_per_batch))  

The meaning of 'run_context' should be easy to guess. before_run() returns a tf.train.SessionRunArgs, which asks the session to fetch the given tensors; the fetched values are handed to after_run() as 'run_values', and you read them via run_values.results. So tf.train.SessionRunArgs(some_tensor) effectively fetches that tensor's value for use in after_run(). Interesting!
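
As a final sketch, SessionRunArgs accepts anything Session.run() accepts, so a hook can request several tensors at once, and predefined and custom hooks can be mixed in the same session. The loss, global_step, and train_op names below are assumed to be defined when building the graph, and the frequencies are illustrative:

...create graph, defining loss, train_op and global_step...

class _MultiTensorHook(tf.train.SessionRunHook):
  """Fetches several tensors per step; results keep the same structure."""

  def before_run(self, run_context):
    # A dict of fetches; run_values.results will be a dict with the same keys.
    return tf.train.SessionRunArgs({'loss': loss, 'step': global_step})

  def after_run(self, run_context, run_values):
    if run_values.results['step'] % 100 == 0:
      print('step %d, loss %.3f'
            % (run_values.results['step'], run_values.results['loss']))

with tf.train.MonitoredTrainingSession(
    checkpoint_dir="/my/training/directory",
    hooks=[
        # Predefined hook: logs the named tensors every 100 steps.
        tf.train.LoggingTensorHook(tensors={'loss': loss}, every_n_iter=100),
        _MultiTensorHook()]) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)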

