TensorFlow中CIFAR10的学习

今天学习了下TensorFlow官方网站上CIFAR10的部分,发现有一些API以前没有见过,这里整理了一下。
CIFAR10教程地址

1.首先是一些参数的初始化

FLAGS = tf.app.flags.FLAGS

# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/temp/cifar10_data',
                           """cifar10_inputath to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
                            """Train the model using fp16.""")

TensorFlow 的API中没有找到相关说明,查了一下,发现这个tf.app.flags主要是为了解析命令行传递的参数。

import tensorflow as tf
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('param1', 'default', """string""")
tf.app.flags.DEFINE_bool("param2", False, """bool""")

def main(_):
    print FLAGS.param1
    print FLAGS.param2

if __name__ == "__main__":
    tf.app.run()

如果运行这个文件python test.py,输出的是默认的参数

default
False

如果运行python test.py –param1 test –param2 True 输出为

test
True

2.collection容器的使用
代码里面有个函数tf.add_to_collection(‘losses’, weight_decay)以前没有见过

def _variable_with_weight_decay(name, shape, stddev, wd):
    """Helper to create an initialized Variable with weight decay.
    A weight decay is added only if one is specified.

    Args:
      name: name of the variable
      shape: list of ints
      stddev: standard deviation of a truncated Gaussian
      wd: add L2Loss weight decay multiplied by this float. If None, weight
          decay is not added for this Variable.
    Returns:
      Variable Tensor
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = _variable_on_cpu(
        name,
        shape,
        tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)
    )
    if wd is not None:
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var

tf.add_to_collection(name, value) 是指用默认的图将Graph.add_to_collection()封装起来,也就是说将数据存储在collection中,名字即为参数中的name

Args:

  • name: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
  • value: The value to add to the collection.

之后,可以通过tf.get_collection(name)取出存储过的所有值。

var0 = tf.Variable(tf.constant(1.0))
value = 2 * var0
tf.add_to_collection('value', value)
value = 3 * var0
tf.add_to_collection('value', value)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

test = tf.get_collection("value")
print sess.run(test)

这里输出的结果是

[2.0, 3.0]

3.滑动平均
一些训练算法,如梯度下降法,Momentum法可以通过在优化中进行滑动平均来取得更好的效果。
TensorFlow中提供tf.train.ExponentialMovingAverage。

shadow_variable -= (1 - decay) * (shadow_variable - variable)

也就是

shadow_variable = decay * shadow_variable + (1 - decay) * variable

decay通常选0.999,0.9999等

有两种方法可以实现评估中的滑动平均

  • 创建模型时用shadow variable。average()函数可以返回shadow variable
  • 利用shadow value的名字来加载checkpoint文件,用average_name()函数。

在这里直接将平均后的值储存起来

def _add_loss_summaries(total_loss):
    """Add summaries for losses in CIFAR-10 model.
    Args:
      total_loss: Total loss from loss().
    Returns:
      loss_average_op: op for generating moving averages of losses.
    """
    #Compute the moving average of all individual losses and the total loss.
    loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
    losses = tf.get_collection('losses')
    loss_averages_op = loss_averages.apply(losses + [total_loss])

    # Attach a scalar summary to all individual losses and the total loss;
    # do the same for the averaged version of the losses.
    for l in losses + [total_loss]:
        tf.summary.scalar(l.op.name + '(raw', l)
        tf.summary.scalar(l.op.name, loss_averages.average(l))
    return loss_averages_op

4.控制依赖
代码中还有一个函数tf.control_dependencies(control_inputs)没有见过.这个函数
用默认的图将tf.control_dependencies()封装起来,是指首先执行control_inputs中的对象,在执行下面的操作。

with tf.control_dependencies(loss_averages_op):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

5.创建global step
之前我一直是直接创建了一个global step变量来记录步数

global_step = tf.Variable(0, trainable=False)

但在这里,用了tf.contrib类的函数来创建,返回并创建了一个global step变量

tf.contrib.framework.get_or_create_global_step(graph=None)

6.创建MonitoredSession

MonitoredTrainingSession()用于创建MonitoredSession,可以用于创建一些和checkpoint与summary相关的hook。(hook是在训练/评估模型中的工具)
StopAtStepHook:在特定的步数后请求停止,也就是到达FLAGS.max_steps步后停止,
NanTensorHook:在loss是NaN时,监控loss,停止训练
checkpoint_dir:指定存储变量的路径
config:tf.ConfigProto的实例,用于配置session。主要制定设备类型与名字,如“CPU”“GPU”

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.train_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)
    ) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)

由于训练起来太慢了,所以我这里的步数max_steps: 100000。
之后进行evaluate


evaluation

1.读取checkpoint
tf.train.get_checkpoint_state: 从checkpoint文件中返回Checkpoint状态
这一块跟变量的保存有关了。

class tf.train.Saver用于保存和恢复变量
checkpoint就是一些从变量名字到tensor的映射。

# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

2.线程和队列
队列Queues是TensorFlow中一个异步计算机制,queue也是图中的一个节点,其他的节点可以更改它的内容,此外,节点还可以进行入队enqueue,和出队dequeue操作。

Session是一个多线程的对象,因此多个线程可以使用同一个session,并行地运行ops。TensorFlow中提供了两个类tf.train.Coodinatortf.train.QueueRunner来实现多线程的管理,这两个类必须同时使用。

  • Coordinator类使得多个线程同时停止,且可以汇报exception
  • QueueRunner类用来创建在一些线程,使得它们可以将在同一个队列中进行tensor的入队操作。

首先创建一个Coordinator对象,之后创建使用这个对象的一些线程,这些线程可以在循环中运行,并且在should_stop()返回True时停止。
任何线程都可以使计算停止,只需要运行request_stop(),其他的线程就会在should_stop()返回真时,停止。

QueueRunner类创建了一些可以重复运行入队操作的线程,这些线程可以用同一个Coordinator来控制停止。除此之外,一个queuerunner在意外发生时自动关闭线程。

3.评估
这部分的代码如下:

def evaluate():
    """Eval CIFAR-10 for a number of steps."""
    with tf.Graph().as_default() as g:
        eval_data = FLAGS.eval_data == 'test'
        images, labels = cifar10.inputs(eval_data=eval_data)

        # Build a Graph that computes the logits predictions from the inference model
        logits = cifar10.inference(images)
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        variable_averages = tf.train.ExponentialMovingAverage(cifar10.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
        while True:
            eval_once(saver, summary_writer, top_k_op, summary_op)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)

其中用到了一个函数tf.nn.in_top_k(predictions, targets, k, name=None)
来判断targes是否在前k个预测中。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值