今天学习了下TensorFlow官方网站上CIFAR10的部分,发现有一些API以前没有见过,这里整理了一下。
CIFAR10教程地址
1.首先是一些参数的初始化
FLAGS = tf.app.flags.FLAGS
# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/temp/cifar10_data',
"""cifar10_inputath to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
"""Train the model using fp16.""")
TensorFlow 的API中没有找到相关说明,查了一下,发现这个tf.app.flags主要是为了解析命令行传递的参数。
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('param1', 'default', """string""")
tf.app.flags.DEFINE_bool("param2", False, """bool""")
def main(_):
print FLAGS.param1
print FLAGS.param2
if __name__ == "__main__":
tf.app.run()
如果运行这个文件python test.py,输出的是默认的参数
default
False
如果运行python test.py –param1 test –param2 True 输出为
test
True
2.collection容器的使用
代码里面有个函数tf.add_to_collection(‘losses’, weight_decay)以前没有见过
def _variable_with_weight_decay(name, shape, stddev, wd):
"""Helper to create an initialized Variable with weight decay.
A weight decay is added only if one is specified.
Args:
name: name of the variable
shape: list of ints
stddev: standard deviation of a truncated Gaussian
wd: add L2Loss weight decay multiplied by this float. If None, weight
decay is not added for this Variable.
Returns:
Variable Tensor
"""
dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
var = _variable_on_cpu(
name,
shape,
tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)
)
if wd is not None:
weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
tf.add_to_collection('losses', weight_decay)
return var
tf.add_to_collection(name, value) 是指用默认的图将Graph.add_to_collection()封装起来,也就是说将数据存储在collection中,名字即为参数中的name
Args:
- name: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
- value: The value to add to the collection.
之后,可以通过tf.get_collection(name)取出存储过的所有值。
var0 = tf.Variable(tf.constant(1.0))
value = 2 * var0
tf.add_to_collection('value', value)
value = 3 * var0
tf.add_to_collection('value', value)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
test = tf.get_collection("value")
print sess.run(test)
这里输出的结果是
[2.0, 3.0]
3.滑动平均
一些训练算法,如梯度下降法,Momentum法可以通过在优化中进行滑动平均来取得更好的效果。
TensorFlow中提供tf.train.ExponentialMovingAverage。
shadow_variable -= (1 - decay) * (shadow_variable - variable)
也就是
shadow_variable = decay * shadow_variable + (1 - decay) * variable
decay通常选0.999,0.9999等
有两种方法可以实现评估中的滑动平均
- 创建模型时用shadow variable。average()函数可以返回shadow variable
- 利用shadow value的名字来加载checkpoint文件,用average_name()函数。
在这里直接将平均后的值储存起来
def _add_loss_summaries(total_loss):
"""Add summaries for losses in CIFAR-10 model.
Args:
total_loss: Total loss from loss().
Returns:
loss_average_op: op for generating moving averages of losses.
"""
#Compute the moving average of all individual losses and the total loss.
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
losses = tf.get_collection('losses')
loss_averages_op = loss_averages.apply(losses + [total_loss])
# Attach a scalar summary to all individual losses and the total loss;
# do the same for the averaged version of the losses.
for l in losses + [total_loss]:
tf.summary.scalar(l.op.name + '(raw', l)
tf.summary.scalar(l.op.name, loss_averages.average(l))
return loss_averages_op
4.控制依赖
代码中还有一个函数tf.control_dependencies(control_inputs)没有见过.这个函数
用默认的图将tf.control_dependencies()封装起来,是指首先执行control_inputs中的对象,在执行下面的操作。
with tf.control_dependencies(loss_averages_op):
opt = tf.train.GradientDescentOptimizer(lr)
grads = opt.compute_gradients(total_loss)
5.创建global step
之前我一直是直接创建了一个global step变量来记录步数
global_step = tf.Variable(0, trainable=False)
但在这里,用了tf.contrib类的函数来创建,返回并创建了一个global step变量
tf.contrib.framework.get_or_create_global_step(graph=None)
6.创建MonitoredSession
MonitoredTrainingSession()用于创建MonitoredSession,可以用于创建一些和checkpoint与summary相关的hook。(hook是在训练/评估模型中的工具)
StopAtStepHook:在特定的步数后请求停止,也就是到达FLAGS.max_steps步后停止,
NanTensorHook:在loss是NaN时,监控loss,停止训练
checkpoint_dir:指定存储变量的路径
config:tf.ConfigProto的实例,用于配置session。主要制定设备类型与名字,如“CPU”“GPU”
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
由于训练起来太慢了,所以我这里的步数max_steps: 100000。
之后进行evaluate
evaluation
1.读取checkpoint
tf.train.get_checkpoint_state: 从checkpoint文件中返回Checkpoint状态
这一块跟变量的保存有关了。
class tf.train.Saver用于保存和恢复变量
checkpoint就是一些从变量名字到tensor的映射。
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
# Append the step number to the checkpoint name:
saver.save(sess, 'my-model', global_step=step)
2.线程和队列
队列Queues是TensorFlow中一个异步计算机制,queue也是图中的一个节点,其他的节点可以更改它的内容,此外,节点还可以进行入队enqueue,和出队dequeue操作。
Session是一个多线程的对象,因此多个线程可以使用同一个session,并行地运行ops。TensorFlow中提供了两个类tf.train.Coodinator和tf.train.QueueRunner来实现多线程的管理,这两个类必须同时使用。
- Coordinator类使得多个线程同时停止,且可以汇报exception
- QueueRunner类用来创建在一些线程,使得它们可以将在同一个队列中进行tensor的入队操作。
首先创建一个Coordinator对象,之后创建使用这个对象的一些线程,这些线程可以在循环中运行,并且在should_stop()返回True时停止。
任何线程都可以使计算停止,只需要运行request_stop(),其他的线程就会在should_stop()返回真时,停止。
QueueRunner类创建了一些可以重复运行入队操作的线程,这些线程可以用同一个Coordinator来控制停止。除此之外,一个queuerunner在意外发生时自动关闭线程。
3.评估
这部分的代码如下:
def evaluate():
"""Eval CIFAR-10 for a number of steps."""
with tf.Graph().as_default() as g:
eval_data = FLAGS.eval_data == 'test'
images, labels = cifar10.inputs(eval_data=eval_data)
# Build a Graph that computes the logits predictions from the inference model
logits = cifar10.inference(images)
top_k_op = tf.nn.in_top_k(logits, labels, 1)
variable_averages = tf.train.ExponentialMovingAverage(cifar10.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
while True:
eval_once(saver, summary_writer, top_k_op, summary_op)
if FLAGS.run_once:
break
time.sleep(FLAGS.eval_interval_secs)
其中用到了一个函数tf.nn.in_top_k(predictions, targets, k, name=None)
来判断targes是否在前k个预测中。