在TensorFlow中,tf.cond()类似于if...else...,用来控制数据流向,但是仅仅类似而已,其中差别还是挺大的。
format:tf.cond(pred, fn1, fn2, name=None)
Return :either fn1() or fn2() based on the boolean predicate `pred`. # (注意这里,也就是说'fnq'和‘fn2’是两个函数)
arguments:`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have the same non-zero number and type of outputs # ('fnq'和‘fn2’返回的是非零的且类型相同的输出)
官方例子:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
上面例子执行这样的操作,如果x<y为真则result这个操作是tf.add(x,z),反之则是tf.square(y)。这一点上,确实很像逻辑控制中的if...else...,但是官方说明里也提到Since z is needed for at least one branch of the cond,branch of the cond, the tf.mul operation is always executed, unconditionally.
因为z在cond函数中的至少一个分支被用到,所以
z = tf.multiply(a, b)
总是被无条件执行。
Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics.
翻译过来应该是:尽管这样的操作与TensorFlow的数据流模型一致,但是偶尔还是会令那些期望慵懒语法的用户吃惊。
因为TensorFlow是基于图的计算,数据以流的形式存在,所以只要构建好了图,有数据源,那么应该都会 数据流过,所以在执行tf.cond之前,两个数据流一个是tf.add()中的x,z,一个是tf.square(y)中的y,而tf.cond()就决定了是数据流x,z从tf.add()流过,还是数据流y从tf.square()流过。这里这个tf.cond也就像个控制水流的阀门,水流管道x,z,y在这个阀门交汇,而tf.cond决定了谁将流向后面的管道,但是不管哪一个水流流向下一个管道,在阀门作用之前,水流应该都是要到达阀门的。
示例:
import tensorflow as tf
a=tf.constant(2)
b=tf.constant(3)
x=tf.constant(4)
y=tf.constant(5)
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
with tf.Session() as session:
print(result.eval())