tf.cond 与 tf.control_dependencies 的控制问题

原创 2017年04月18日 15:18:41

问题引入


在搜索tf.cond的使用方法时,找到了这样的一个问题:

运行下面的一段tensorflow代码:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

从代码上看,tf.cond经过判断pred的值对x进行更新。但实际上无论在pred = Ture 还是 False,输出的结果都是2,都是pred = tf.constant(True)的情况。

Confused by the behavior of tf.cond

这是怎么回事呢?

顺序执行


先不进行解释,有人在回复中给出了一个可以正确运行的代码,看一下有什么区别:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]

区别也不大,只是把assign_x_2 = tf.assign(x, [2])这句整体移动到了tf.control_dependencies([tf.assign(x, [2])])的内部。
给出的解释是:

如果要让tf.cond()在其中一个分支中执行命令(如分配),你必须在你要传递给的函数创建执行副命令的操作。
If you want to perform a side effect (like an assignment) in one of the branches, you must create the op that performs the side effect inside the function that you pass to .
因为在TensorFlow图中的执行是依次向前流过图形的,所以在任一分支中引用的所有操作必须在条件进行求值之前执行。这意味着true和false分支都接受对tf.assign() op 的控制依赖。
Because execution in a TensorFlow graph flows forward through the graph, all operations that you refer to in either branch must execute before the conditional is evaluated. This means that both the true and the false branches receive a control dependency on the tf.assign() op.

翻译的可能不够准确,大意就是assign_x_2 = tf.assign(x, [2])这句话在tf.cond已经执行过了,因此无论执行update_x_2(让x=2)或lambda: tf.identity(x)(保持x不变),得到的结果都是x=2
这么来看其实是一个很简单的问题,定义时不仅定义了模型,也隐含着定义了执行顺序。

tf.control_dependencies()


这个函数加不加看起来没有什么区别,比如:

import tensorflow as tf                                                                                                                                
pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
# x_2 = tf.assign(x, [2])
def update_x_2():
     # with tf.control_dependencies([x_2]): #[tf.assign(x, [2])]):
     return tf.assign(x, [2])
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
     session.run(tf.global_variables_initializer())
     print(y.eval(feed_dict={pred: False}))  # ==> [1]
     print(y.eval(feed_dict={pred: True}))   # ==> [2]

去掉之后运行结果和正确的相同。具体作用还是看一下官网吧……
直接搜tf.control_dependencies得到的信息并不多:

Wrapper for Graph.control_dependencies() using the default graph.
See tf.Graph.control_dependencies for more details.

tf.Graph.control_dependencies这里确实讲得很详细,其作用简单来说就是控制计算顺序

with g.control_dependencies([a, b, c]):
  # `d` and `e` will only run after `a`, `b`, and `c` have executed.
  d = ...
  e = ...

有了这句话,with中的语句就会在control_dependencies()中的操作执行之后运行,并且也支持嵌套操作。在给出的错误例子中,很像开头提出的问题:

# WRONG
def my_func(pred, tensor):
  t = tf.matmul(tensor, tensor)
  with tf.control_dependencies([pred]):
    # The matmul op is created outside the context, so no control
    # dependency will be added.
    return t

# RIGHT
def my_func(pred, tensor):
  with tf.control_dependencies([pred]):
    # The matmul op is created in the context, so a control dependency
    # will be added.
    return tf.matmul(tensor, tensor)

上面t操作在tf.control_dependencies之前已经被执行了,因此就无法控制t的先后顺序。如果我们把my_func看作是tf.cond中的分支操作函数,那么很可能在pred更新之前就已经进行了操作,因此可能造成一些错误。

总结


这么一看,好像我自己写的没有注意这么多细节,但目前从结果上看好像还都没什么问题,或许需要重新改写一下。

版权声明:本文为博主原创文章,转载请标注出处。

tf.cond 函数用法

z = tf.multiply(a, b) result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))If x < y, t...
  • Eric_LH
  • Eric_LH
  • 2017年12月29日 16:52
  • 136

tf.control_dependencies()作用及用法

在有些机器学习程序中我们想要指定某些操作执行的依赖关系,这时我们可以使用tf.control_dependencies()来实现。 control_dependencies(control_inpu...
  • PKU_Jade
  • PKU_Jade
  • 2017年06月20日 15:33
  • 5803

tf.cond()的用法

由于tensorflow使用的是graph的计算概念,在没有涉及控制数据流向的时候编程和普通编程语言的编程差别不大,但是涉及到控制数据流向的操作时,就要特别小心,不然很容易出错。这也是TensorFl...
  • m0_37041325
  • m0_37041325
  • 2017年08月08日 15:18
  • 4835

tf.cond和tf.case

tensorflow的逻辑控制 关于tf.case理解的不是很好: https://stackoverflow.com/questions/41910073/tensorflow-tf-case-...
  • u014221266
  • u014221266
  • 2017年11月16日 19:25
  • 513

tf.control_dependencies()

参考这里点击打开链接的信息我们可以知道,TF可以协调多个数据流,在存在依赖的节点下非常有用,例如节点B要读取模型参数值V更新后的值,而节点A负责更新参数V,所以节点B就要等节点A执行完成后再执行,不然...
  • m0_37041325
  • m0_37041325
  • 2017年08月08日 20:42
  • 956

(Tensorflow之十)tf.control_dependencies()用法

先看一下官方API文档 control_dependencies是用于控制计算流图的先后顺序的。必需先完成control_input的计算,才能执行之后定义的context。 但是,ten...
  • abiggg
  • abiggg
  • 2018年01月10日 01:36
  • 86

tensorflow学习笔记(四十一):control dependencies

tensorflowtf.control_dependencies()设计是用来控制计算流图的,给图中的某些计算指定顺序。比如:我们想要获取参数更新后的值,那么我们可以这么组织我们的代码。 opt =...
  • u012436149
  • u012436149
  • 2017年05月14日 23:48
  • 6796

tensorflow训练时的一些注意事项

1,使用batch norm层后,计算损失时,注意添加相应操作 def conv_bn_relu(inputs, num_outputs, phase, kernel_size, stride=1, ...
  • ying86615791
  • ying86615791
  • 2017年06月28日 21:56
  • 910

BN(batch Normalization)笔记

l  BN(batch Normalization) What is BN 通常在神经网络训练开始前,都要对输入数据做一个归一化处理 Why BN? 1.     提升泛华能力 神经网络学习过程本质就...
  • zwlq1314521
  • zwlq1314521
  • 2017年09月11日 11:16
  • 226

TensoFlow实现条件语句

import tensorflow as tf a = tf.constant(20) b = tf.constant(10) result1 = tf.cond(a > b, lambda: a...
  • helei001
  • helei001
  • 2016年11月11日 10:40
  • 4258
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:tf.cond 与 tf.control_dependencies 的控制问题
举报原因:
原因补充:

(最多只允许输入30个字)