tf.cond 与 tf.control_dependencies 的控制问题

原创 2017年04月18日 15:18:41

问题引入


在搜索tf.cond的使用方法时,找到了这样的一个问题:

运行下面的一段tensorflow代码:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

从代码上看,tf.cond经过判断pred的值对x进行更新。但实际上无论在pred = Ture 还是 False,输出的结果都是2,都是pred = tf.constant(True)的情况。

Confused by the behavior of tf.cond

这是怎么回事呢?

顺序执行


先不进行解释,有人在回复中给出了一个可以正确运行的代码,看一下有什么区别:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]

区别也不大,只是把assign_x_2 = tf.assign(x, [2])这句整体移动到了tf.control_dependencies([tf.assign(x, [2])])的内部。
给出的解释是:

如果要让tf.cond()在其中一个分支中执行命令(如分配),你必须在你要传递给的函数创建执行副命令的操作。
If you want to perform a side effect (like an assignment) in one of the branches, you must create the op that performs the side effect inside the function that you pass to .
因为在TensorFlow图中的执行是依次向前流过图形的,所以在任一分支中引用的所有操作必须在条件进行求值之前执行。这意味着true和false分支都接受对tf.assign() op 的控制依赖。
Because execution in a TensorFlow graph flows forward through the graph, all operations that you refer to in either branch must execute before the conditional is evaluated. This means that both the true and the false branches receive a control dependency on the tf.assign() op.

翻译的可能不够准确,大意就是assign_x_2 = tf.assign(x, [2])这句话在tf.cond已经执行过了,因此无论执行update_x_2(让x=2)或lambda: tf.identity(x)(保持x不变),得到的结果都是x=2
这么来看其实是一个很简单的问题,定义时不仅定义了模型,也隐含着定义了执行顺序。

tf.control_dependencies()


这个函数加不加看起来没有什么区别,比如:

import tensorflow as tf                                                                                                                                
pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
# x_2 = tf.assign(x, [2])
def update_x_2():
     # with tf.control_dependencies([x_2]): #[tf.assign(x, [2])]):
     return tf.assign(x, [2])
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
     session.run(tf.global_variables_initializer())
     print(y.eval(feed_dict={pred: False}))  # ==> [1]
     print(y.eval(feed_dict={pred: True}))   # ==> [2]

去掉之后运行结果和正确的相同。具体作用还是看一下官网吧……
直接搜tf.control_dependencies得到的信息并不多:

Wrapper for Graph.control_dependencies() using the default graph.
See tf.Graph.control_dependencies for more details.

tf.Graph.control_dependencies这里确实讲得很详细,其作用简单来说就是控制计算顺序

with g.control_dependencies([a, b, c]):
  # `d` and `e` will only run after `a`, `b`, and `c` have executed.
  d = ...
  e = ...

有了这句话,with中的语句就会在control_dependencies()中的操作执行之后运行,并且也支持嵌套操作。在给出的错误例子中,很像开头提出的问题:

# WRONG
def my_func(pred, tensor):
  t = tf.matmul(tensor, tensor)
  with tf.control_dependencies([pred]):
    # The matmul op is created outside the context, so no control
    # dependency will be added.
    return t

# RIGHT
def my_func(pred, tensor):
  with tf.control_dependencies([pred]):
    # The matmul op is created in the context, so a control dependency
    # will be added.
    return tf.matmul(tensor, tensor)

上面t操作在tf.control_dependencies之前已经被执行了,因此就无法控制t的先后顺序。如果我们把my_func看作是tf.cond中的分支操作函数,那么很可能在pred更新之前就已经进行了操作,因此可能造成一些错误。

总结


这么一看,好像我自己写的没有注意这么多细节,但目前从结果上看好像还都没什么问题,或许需要重新改写一下。

版权声明:本文为博主原创文章,转载请标注出处。

相关文章推荐

tf.cond()的用法

由于tensorflow使用的是graph的计算概念,在没有涉及控制数据流向的时候编程和普通编程语言的编程差别不大,但是涉及到控制数据流向的操作时,就要特别小心,不然很容易出错。这也是TensorFl...

Theano-Deep Learning Tutorials 笔记:Stacked Denoising Autoencoders (SdA)

栈式降噪自编码,先逐层无监督预训练,再整个网络有监督微调,python实现。...

TF31002问题解决办法

TFS连接出现TF31002问题的解决办法;自己总结,希望对您有帮助

阅读源码遇到的一些TF、keras函数及问题2(--小白笔记)

numpyhstackab与numpyvstackab numpytileab keraslayerscoreDense keraslayersconvolutionalConvolution2D k...

TFS问题集 TF30224: 未能从报表服务器检索项目.

---开始异常项--- 时间: 2011-11-20 20:10:10Z 模块: Initializer 事件说明: TF30207: 插件“Microsoft.ProjectCreationWiza...
  • FBug
  • FBug
  • 2011年11月20日 23:24
  • 1597

应用于文本分类问题的TF-IDF改进方法

TF-IDF是一种统计方法,用以评估某一字词对于一个文件集或一个语料库中的其中一份文件的重要程度。字词的重要性随着它在文件中出现的次数成正比增加,但同时会随着它在语料库中出现的频率成反比下降。二、传统...

tensorflow conv2d padding,tf图像卷积边缘扩展问题

初学tensorflow的conv2d的时候,一般书上会说conv2d的扩展可以选择两种,SAME和VALID。这两种要么导致图像变小(valid),要么导致边缘变黑(same),因为边缘只补0。曾一...

远程控制问题集锦(你的凭据不工作,之前用于连接到(服务器IP)的凭据无法工作,请输入新的凭据)

成也名称,败也名称,由于忘了自己的电脑用户名,远程的时候走了很多弯路,但是从中了解到了很多远程知识;你的凭据不工作,之前用于连接到(服务器IP)的凭据无法工作,请输入新的凭据。...
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:tf.cond 与 tf.control_dependencies 的控制问题
举报原因:
原因补充:

(最多只允许输入30个字)