tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究
tf.control_dependencies
首先我们先介绍tf.control_dependencies
,该函数保证其辖域中的操作必须要在该函数所传递的参数中的操作完成后再进行。请看下面一个例子。
import tensorflow as tf
a_1 = tf.Variable(1)
b_1 = tf.Variable(2)
update_op = tf.assign(a_1, 10)
add = tf.add(a_1, b_1)
a_2 = tf.Variable(1)
b_2 = tf.Variable(2)
update_op = tf.assign(a_2, 10)
with tf.control_dependencies([update_op]):
add_with_dependencies = tf.add(a_2, b_2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
ans_1, ans_2 = sess.run([add, add_with_dependencies])
print("Add: ", ans_1)
print("Add_with_dependency: ", ans_2)
输出:
Add: 3
Add_with_dependency: 12
可以看到两组加法进行的对比,正常的计算图在计算add时是不会经过update_op操作的,因此在加法时a的值为1,但是采用tf.control_dependencies函数,可以控制在进行add前先完成update_op的操作,因此在加法时a的值为10,因此最后两种加法的结果不同。
2、tf.GraphKeys.UPDATE_OPS
关于tf.GraphKeys.UPDATE_OPS,这是一个tensorflow的计算图中内置的一个集合,其中会保存一些需要在训练操作之前完成的操作,并配合tf.control_dependencies函数使用。
关于在batch_norm中,即为更新mean和variance的操作。通过下面一个例子可以看到tf.layers.batch_normalization中是如何实现的。
import tensorflow as tf
is_traing = tf.placeholder(dtype=tf.bool)
input = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(input, training=is_traing)
# 打印batch_normalization中的两个操作
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)
# with tf.control_dependencies(update_ops):
# train_op = optimizer.minimize(loss)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, "batch_norm_layer/Model")
输出:
[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(3,) dtype=float32_ref>,
<tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(3,) dtype=float32_ref>]
可以看到输出的即为两个batch_normalization中更新mean和variance的操作,需要保证它们在train_op前完成。
这两个操作是在tensorflow的内部实现中自动被加入tf.GraphKeys.UPDATE_OPS这个集合的,在tf.contrib.layers.batch_norm的参数中可以看到有一项updates_collections的默认值即为tf.GraphKeys.UPDATE_OPS,而在tf.layers.batch_normalization中则是直接将两个更新操作放入了上述集合。