This post covers two main ways to monitor training.
Before getting into them, a quick word about tf.train.Saver, which saves variables to disk. It can be used as follows; you can save all variables or only a specified subset.
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# If you want to save some variables not all of them
# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    ...
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
Restore variables like this:
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
1. tf.train.Supervisor
...create graph...
my_train_op = ...
sv = tf.train.Supervisor(logdir="/my/training/directory")
with sv.managed_session() as sess:
    for step in range(100000):
        if sv.should_stop():
            break
        sess.run(my_train_op)
tf.train.Supervisor is the simpler of the two approaches. In the example above, it saves the graph's variables to 'logdir' every 10 minutes (the default), automatically runs all summary ops every 2 minutes (the default) and writes the event file to 'logdir', keeps track of the global step, and starts tf.train.QueueRunner threads of its own. You can also use it to check for a stopping point with sv.should_stop(): define your own stop criterion yourself, and when it is met call sv.request_stop(); the next check of should_stop() then returns True and the loop exits.
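The request_stop()/should_stop() pattern can be sketched without TensorFlow at all. `StopFlag` below is a hypothetical stand-in for the Supervisor's stop mechanism, and the loss values are made up; the point is only the control flow: the criterion triggers request_stop(), and the loop exits on the next should_stop() check.

```python
class StopFlag:
    """Hypothetical stand-in for tf.train.Supervisor's stop mechanism."""

    def __init__(self):
        self._stop = False

    def request_stop(self):
        self._stop = True

    def should_stop(self):
        return self._stop


sv = StopFlag()
losses = [0.9, 0.5, 0.009, 0.4]  # made-up per-step loss values
steps_run = 0
for loss in losses:
    if sv.should_stop():   # checked at the top of every step
        break
    steps_run += 1
    if loss < 0.01:        # our own stop criterion
        sv.request_stop()  # triggers should_stop() on the next check

print(steps_run)  # the step that met the criterion still finishes
```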
If you do not want it to save summaries automatically (saving them too frequently can exhaust memory), maintain the summaries yourself as follows:
...create graph...
my_train_op = ...
my_summary_op = tf.summary.merge_all()
sv = tf.train.Supervisor(logdir="/my/training/directory",
                         summary_op=None)  # Do not run the summary service
with sv.managed_session() as sess:
    for step in range(100000):
        if sv.should_stop():
            break
        if step % 100 == 0:
            _, summ = sess.run([my_train_op, my_summary_op])
            sv.summary_computed(sess, summ)
        else:
            sess.run(my_train_op)
You can of course also restore a model from a checkpoint file, e.g. initialize your model's variables with the parameters of a pre-trained model. Pass an init function to the 'init_fn' argument of tf.train.Supervisor() like this:
...create graph...
# Create a saver that restores only the pre-trained variables.
pre_train_saver = tf.train.Saver([pre_train_var1, pre_train_var2])
# Define an init function that loads the pretrained checkpoint.
def load_pretrain(sess):
    pre_train_saver.restore(sess, "<path to pre-trained-checkpoint>")
# Pass the init function to the supervisor.
#
# The init function is called _after_ the variables have been initialized
# by running the init_op.
sv = tf.train.Supervisor(logdir="/my/training/directory",
                         init_fn=load_pretrain)
with sv.managed_session() as sess:
    # Here sess was either initialized from the pre-trained checkpoint or
    # recovered from a checkpoint saved in a previous run of this code.
    ...
You can also set checkpoint_basename, save_model_secs, saver, save_summaries_secs, summary_op, and so on. For example:
...create graph...
my_saver = tf.train.Saver(<only some variables>)
sv = tf.train.Supervisor(logdir="/my/training/directory",
                         saver=my_saver,
                         save_model_secs=30)
with sv.managed_session() as sess:
    ...training loop...
You can also register a service of your own:
def my_additional_summaries(sv, sess):
    summaries = sess.run(my_additional_summary_op)
    sv.summary_computed(sess, summaries)
...
sv = tf.train.Supervisor(logdir="/my/training/directory")
with sv.managed_session() as sess:
    # Call my_additional_summaries() every 1200 s (20 minutes),
    # passing (sv, sess) as arguments.
    sv.loop(1200, my_additional_summaries, args=(sv, sess))
    ...main training loop...
See the official tf.train.Supervisor documentation for more details.
2. tf.train.MonitoredTrainingSession()
The following approach is taken from the CIFAR-10 tutorial. It uses training hooks; I honestly am not sure what these are, having never seen them before, but I will imitate the usage anyway. The signature of tf.train.MonitoredTrainingSession() is rather long...
tf.train.MonitoredTrainingSession(master='', is_chief=True, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=100, config=None)
checkpoint_dir: where to write checkpoints
hooks: 'a list of SessionRunHook objects'. As I understand it, a hook implements some per-step functionality, and hook class definitions all seem to follow a fixed pattern; an example is given below
save_checkpoint_secs/save_summaries_steps: self-explanatory from the names
config: 'an instance of tf.ConfigProto', used to configure the session, e.g. CPU/GPU placement; see tf.ConfigProto
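As a sketch of the config parameter (assuming a TensorFlow 1.x environment), a ConfigProto might enable device-placement logging and on-demand GPU memory allocation; these particular option choices are illustrative, not from the CIFAR-10 code:

```python
import tensorflow as tf

config = tf.ConfigProto(
    log_device_placement=True,   # print which device each op is placed on
    allow_soft_placement=True)   # fall back to CPU if a GPU kernel is missing
config.gpu_options.allow_growth = True  # claim GPU memory on demand, not all upfront
# This config object is then passed as the config= argument
# of tf.train.MonitoredTrainingSession() or tf.Session().
```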
Here is a usage example:
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
This looks a bit more compact than the Supervisor version. TensorFlow ships with many predefined hooks that can be used directly; see the API documentation for their descriptions and the source for the corresponding code. You can also build your own hook class, which should define the three methods begin(), before_run(), and after_run():
class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def begin(self):
        self._step = -1
        self._start_time = time.time()

    def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

    def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
            current_time = time.time()
            duration = current_time - self._start_time
            self._start_time = current_time

            loss_value = run_values.results
            examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
            sec_per_batch = float(duration / FLAGS.log_frequency)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))
The meaning of 'run_context' should be guessable. The tf.train.SessionRunArgs(some_tensor) that before_run() returns asks the session to fetch that tensor's value; the fetched value is then delivered to after_run() through the 'run_values' argument, where you read it as run_values.results. Interesting!
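The hook lifecycle described above can be sketched without TensorFlow. `RunValues` and `run_session` below are hypothetical stand-ins for what MonitoredTrainingSession does internally: begin() once, then before_run() → (session fetches the requested value) → after_run() on every step, with the fetched value arriving as run_values.results.

```python
class RunValues:
    """Hypothetical stand-in for the object TF passes to after_run()."""

    def __init__(self, results):
        self.results = results


class AvgLossHook:
    """Toy hook that averages the 'loss' value it asks for in before_run()."""

    def begin(self):
        self._total, self._steps = 0.0, 0

    def before_run(self, run_context):
        return 'loss'  # name of the value we want the session to fetch

    def after_run(self, run_context, run_values):
        self._total += run_values.results  # the fetched value arrives here
        self._steps += 1


def run_session(hook, losses):
    """Minimal driver mimicking MonitoredTrainingSession's hook calls."""
    hook.begin()
    for loss in losses:
        fetch = hook.before_run(run_context=None)
        assert fetch == 'loss'  # the "session" honors the requested fetch
        hook.after_run(run_context=None, run_values=RunValues(loss))
    return hook._total / hook._steps


avg = run_session(AvgLossHook(), [1.0, 2.0, 3.0])
print(avg)  # → 2.0, the average of the fetched losses
```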