概述
接着前面几篇博客讲cifar10的训练,本文牵涉的代码文件主要是cifar10_train.py。
仔细看前面的cifar10.py就会发现,所有变量都被声明放在cpu上,如参考资料里的《卷积神经网络》所说,这是为了多个GPU上共享变量。只有train部分的节点没指定设备,如果有gpu的话训练还是默认在gpu上,其他部分都在cpu上。
代码分解
命令行选项参数
FLAGS = tf.app.flags.FLAGS
# 训练的中间文件将会保存到这个目录下
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
"""Directory where to write event logs """
"""and checkpoint.""")
# 训练停止的最大训练步数
tf.app.flags.DEFINE_integer('max_steps', 100000,
"""Number of batches to run.""")
# 是否显示变量的设备分配情况
tf.app.flags.DEFINE_boolean('log_device_placement', False,
"""Whether to log device placement.""")
# 记录日志的频率,间隔多少step记录一次
tf.app.flags.DEFINE_integer('log_frequency', 10,
"""How often to log results to the console.""")
共4个参数,在调用tf.app.run()的时候被解析,此处定义了一个全局的FLAGS来指向命令行配置。
注意其他文件里开头也有定义FLAGS和其他参数,所有的FLAGS都指向tf.app.flags.FLAGS,由tf.app.run()解析这些参数。tf.app.flags.FLAGS作为一个跨文件的全局变量来统一管理所有的输入参数。
日志回调类
cifar10_train.py定义了一个训练的日志回调类_LoggerHook。因为这个文件的训练不是用python的循环手动控制的,而是用tf.train.MonitoredTrainingSession(),需要为这个类填充各种回调类。
_LoggerHook继承自tf.train.SessionRunHook,每次调用session的run之前会调用它的before_run(),run之后会调用after_run()。
_LoggerHook的主要工作在after_run(),主要用于打印loss和用时。
_LoggerHook的before_run()有一个重要的工作就是每次都把loss节点添加到run的队列中,这样整个计算图才会运行起来,没有tf.train.SessionRunArgs(loss)这一行,计算图就是死的。
class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
sec_per_batch = float(duration / FLAGS.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
训练会话的管理类
cifar10_train.py用tf.train.MonitoredTrainingSession这个类来管理训练活动。其参数如《tf.train.MonitoredTrainingSession()解析》所讲,我就不重复。
主要讲一下回调类,除了上文的日志回调类,这里还添加了一个按步数停止训练的类和一个loss值为Nan时停止训练的类。这两个类的具体代码是预定义好的,填入参数实例化即可用。
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
参考资料
tensorflow官方讲解cifar10例程的文档:卷积神经网络