声明:
tf.control_dependecies()
tf.control_dependencies
是tensorflow中的一个flow顺序控制机制,作用有二:插入依赖(dependencies)和清空依赖(依赖是op或tensor)。常见的tf.control_dependencies
是tf.Graph.control_dependencies
的装饰器,它们用法是一样的。通过本文,你将了解:
- 了解
control_dependencies()
的顺序控制机制 tf.control_dependencies()
在batch normalization中的使用示例control_dependencies()
两种不正确的使用方式
control_dependencies介绍
tf.control_dependencies()
有一个参数control_inputs
(这是一个列表,列表中可以是Operation
或Tensor
对象),返回一个上下文管理器(通常与with
一起使用)。
例1
-
with tf.control_dependencies([a, b, c]):
-
d = ...
-
e = ...
session在运行d、e之前会先运行a、b、c。在with tf.control_dependencies
之内的代码块受到顺序控制机制的影响。
例2
-
with tf.control_dependencies([a, b]):
-
with tf.control_dependencies([c, d]):
-
e = ...
session在运行e之前会先运行a、b、c、d。因为依赖会随着with tf.control_dependencies
的嵌套一直继承下去。
例3
-
with tf.control_dependencies([a, b]):
-
with tf.control_dependencies(
None):
# 第二层上下文管理器
-
with tf.control_dependencies([c, d]):
-
e = ...
session在运行e之前会先运行c、d,不需要运行a、b。因为在第二层的上下文管理器中,参数control_inputs
的值为None
,如此将会清除之前所有的依赖。
在BN中的使用
tf.layers.batch_normalization()
中用到的变量——当前估计的均值和方差是untrainable的,它们通过每个batch的均值和标准差的移动平均更新值,位于collection——tf.GraphKeys.UPDATE_OPS
中。因此,需要在每一轮迭代前插入这个操作:
-
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
-
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
这样,在每次迭代调用train_step
之前都会先当前batch下的均值和方差的移动平均值。
几种不正确的使用方式
例1
op或tensor在流图中的顺序由它们的创建位置决定。
-
# 不正确
-
def my_func(pred, tensor):
-
t = tf.matmul(tensor, tensor)
-
with tf.control_dependencies([pred]):
-
# matmul op的定义在context之外,context内只有一个op或tensor不能继承依赖。
-
return t
-
-
# 应改为
-
def my_func(pred, tensor):
-
with tf.control_dependencies([pred]):
-
# 应将t的创建放到context之内
-
t = tf.matmul(tensor, tensor)
-
return t
例2
tensorflow在求导过程中自动忽略常数项。
-
# 不正确
-
loss = model.loss()
-
with tf.control_dependencies(dependencies):
-
loss = loss + tf.constant(
1)
-
return tf.gradients(loss, model.variables)
因为常数项tf.constant(1)
在back propagation时被忽略了,所以依赖性在BP的时候也不会被执行。
例3
-
w = tf.Variable(1.0)
-
ema = tf.train.ExponentialMovingAverage(0.9)
-
update = tf.assign_add(w,
1.0)
-
-
ema_op = ema.apply([
update])
-
with tf.control_dependencies([ema_op]):
-
ema_val = ema.average(
update)
-
-
with tf.Session()
as sess:
-
tf.global_variables_initializer().run()
-
for i
in
range(
3):
-
print(sess.run([ema_val]))
-
-
# 应改为
-
with tf.control_dependencies([ema_op]):
-
ema_val = tf.identity(ema.average(
update)) # 加一个
identity
看起来好像是在运行ema_val
之前先执行ema_op
,实际不然。因为ema.average(update)不是一个op,它只是从
ema对象的一个字典中取出键对应的
tensor`而已。这个清空跟上文例一很像。
例4
-
import tensorflow as tf
-
w = tf.Variable(1.0)
-
ema = tf.train.ExponentialMovingAverage(0.9)
-
update = tf.assign_add(w,
1.0)
-
-
ema_op = ema.apply([
update])
-
with tf.control_dependencies([ema_op]):
-
w1 = tf.Variable(
2.0)
-
ema_val = ema.average(
update)
-
-
with tf.Session()
as sess:
-
tf.global_variables_initializer().run()
-
for i
in
range(
3):
-
print(sess.run([ema_val, w1]))
这种情况下,control_dependencies
也不工作,原因如下:
-
#这段代码出现在Variable类定义文件中第287行,
-
# 在创建Varible时,tensorflow是移除了dependencies了的
-
#所以会出现 control 不住的情况
-
with ops.control_dependencies(
None):
-
...
转自:
https://blog.csdn.net/hustqb/article/details/83545310