tensorflow之control_dependencies 自学用

声明:

  1. 翻译tensorflow官方文档并进行了总结
  2. 参考博客tensorflow学习笔记(四十一):control dependencies

tf.control_dependecies()

tf.control_dependencies是tensorflow中的一个flow顺序控制机制,作用有二:插入依赖(dependencies)和清空依赖(依赖是op或tensor)。常见的tf.control_dependenciestf.Graph.control_dependencies的装饰器,它们用法是一样的。通过本文,你将了解:

  • 了解control_dependencies()的顺序控制机制
  • tf.control_dependencies()在batch normalization中的使用示例
  • control_dependencies()两种不正确的使用方式

control_dependencies介绍

tf.control_dependencies()有一个参数control_inputs(这是一个列表,列表中可以是OperationTensor对象),返回一个上下文管理器(通常与with一起使用)。

例1


 
 
  1. with tf.control_dependencies([a, b, c]):
  2. d = ...
  3. e = ...

session在运行d、e之前会先运行a、b、c。在with tf.control_dependencies之内的代码块受到顺序控制机制的影响。

例2


 
 
  1. with tf.control_dependencies([a, b]):
  2. with tf.control_dependencies([c, d]):
  3. e = ...

session在运行e之前会先运行a、b、c、d。因为依赖会随着with tf.control_dependencies的嵌套一直继承下去。

例3


 
 
  1. with tf.control_dependencies([a, b]):
  2. with tf.control_dependencies( None): # 第二层上下文管理器
  3. with tf.control_dependencies([c, d]):
  4. e = ...

session在运行e之前会先运行c、d,不需要运行a、b。因为在第二层的上下文管理器中,参数control_inputs的值为None,如此将会清除之前所有的依赖。

在BN中的使用

tf.layers.batch_normalization()中用到的变量——当前估计的均值方差是untrainable的,它们通过每个batch的均值和标准差的移动平均更新值,位于collection——tf.GraphKeys.UPDATE_OPS中。因此,需要在每一轮迭代前插入这个操作:


 
 
  1. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
  2. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
  •  

这样,在每次迭代调用train_step之前都会先当前batch下的均值方差的移动平均值。

几种不正确的使用方式

例1

op或tensor在流图中的顺序由它们的创建位置决定。


 
 
  1. # 不正确
  2. def my_func(pred, tensor):
  3. t = tf.matmul(tensor, tensor)
  4. with tf.control_dependencies([pred]):
  5. # matmul op的定义在context之外,context内只有一个op或tensor不能继承依赖。
  6. return t
  7. # 应改为
  8. def my_func(pred, tensor):
  9. with tf.control_dependencies([pred]):
  10. # 应将t的创建放到context之内
  11. t = tf.matmul(tensor, tensor)
  12. return t

例2

tensorflow在求导过程中自动忽略常数项。


 
 
  1. # 不正确
  2. loss = model.loss()
  3. with tf.control_dependencies(dependencies):
  4. loss = loss + tf.constant( 1)
  5. return tf.gradients(loss, model.variables)

因为常数项tf.constant(1)在back propagation时被忽略了,所以依赖性在BP的时候也不会被执行。

例3


 
 
  1. w = tf.Variable(1.0)
  2. ema = tf.train.ExponentialMovingAverage(0.9)
  3. update = tf.assign_add(w, 1.0)
  4. ema_op = ema.apply([ update])
  5. with tf.control_dependencies([ema_op]):
  6. ema_val = ema.average( update)
  7. with tf.Session() as sess:
  8. tf.global_variables_initializer().run()
  9. for i in range( 3):
  10. print(sess.run([ema_val]))
  11. # 应改为
  12. with tf.control_dependencies([ema_op]):
  13. ema_val = tf.identity(ema.average( update)) # 加一个 identity

看起来好像是在运行ema_val之前先执行ema_op,实际不然。因为ema.average(update)不是一个op,它只是从ema对象的一个字典中取出键对应的tensor`而已。这个清空跟上文例一很像。

例4


 
 
  1. import tensorflow as tf
  2. w = tf.Variable(1.0)
  3. ema = tf.train.ExponentialMovingAverage(0.9)
  4. update = tf.assign_add(w, 1.0)
  5. ema_op = ema.apply([ update])
  6. with tf.control_dependencies([ema_op]):
  7. w1 = tf.Variable( 2.0)
  8. ema_val = ema.average( update)
  9. with tf.Session() as sess:
  10. tf.global_variables_initializer().run()
  11. for i in range( 3):
  12. print(sess.run([ema_val, w1]))

这种情况下,control_dependencies也不工作,原因如下:


 
 
  1. #这段代码出现在Variable类定义文件中第287行,
  2. # 在创建Varible时,tensorflow是移除了dependencies了的
  3. #所以会出现 control 不住的情况
  4. with ops.control_dependencies( None):
  5. ...

 转自:

https://blog.csdn.net/hustqb/article/details/83545310

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值