tensorflow 18:cifar10图片分类训练(cpu或单gpu)

65 篇文章 5 订阅
59 篇文章 6 订阅

概述

接着前面几篇博客讲cifar10的训练,本文牵涉的代码文件主要是cifar10_train.py。

仔细看前面的cifar10.py就会发现,所有变量都被声明放在cpu上,如参考资料里的《卷积神经网络》所说,这是为了多个GPU上共享变量。只有train部分的节点没指定设备,如果有gpu的话训练还是默认在gpu上,其他部分都在cpu上。

代码分解

命令行选项参数

FLAGS = tf.app.flags.FLAGS

# 训练的中间文件将会保存到这个目录下
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
# 训练停止的最大训练步数
tf.app.flags.DEFINE_integer('max_steps', 100000,
                            """Number of batches to run.""")
# 是否显示变量的设备分配情况
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
# 记录日志的频率,间隔多少step记录一次
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")

共4个参数,在调用tf.app.run()的时候被解析,此处定义了一个全局的FLAGS来指向命令行配置。

注意其他文件里开头也有定义FLAGS和其他参数,所有的FLAGS都指向tf.app.flags.FLAGS,由tf.app.run()解析这些参数。tf.app.flags.FLAGS作为一个跨文件的全局变量来统一管理所有的输入参数。

日志回调类

cifar10_train.py定义了一个训练的日志回调类_LoggerHook。因为这个文件的训练不是用python的循环手动控制的,而是用tf.train.MonitoredTrainingSession(),需要为这个类填充各种回调类。

_LoggerHook继承自tf.train.SessionRunHook,每次调用session的run之前会调用它的before_run(),run之后会调用after_run()。

_LoggerHook的主要工作在after_run(),主要用于打印loss和用时。
_LoggerHook的before_run()有一个重要的工作就是每次都把loss节点添加到run的队列中,这样整个计算图才会运行起来,没有tf.train.SessionRunArgs(loss)这一行,计算图就是死的。

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

训练会话的管理类

cifar10_train.py用tf.train.MonitoredTrainingSession这个类来管理训练活动。其参数如《tf.train.MonitoredTrainingSession()解析》所讲,我就不重复。

主要讲一下回调类,除了上文的日志回调类,这里还添加了一个按步数停止训练的类和一个loss值为Nan时停止训练的类。这两个类的具体代码是预定义好的,填入参数实例化即可用。

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)

参考资料

tensorflow官方讲解cifar10例程的文档:卷积神经网络

tf.train.MonitoredTrainingSession()解析

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值