# tf.cond(
# pred,
# true_fn=None,
# false_fn=None,
# strict=False,
# name=None,
# fn1=None,
# fn2=None
# )
# tf.cond: TensorFlow's graph-level counterpart of the ternary operator (三目运算符), i.e. `a if pred else b`
import tensorflow as tf
# Scalar constants feeding the predicate (x > y) and the two branch functions.
# NOTE(review): z is not referenced anywhere in the visible code.
x, y, z = tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)
def f1():
    """Return ``x`` wrapped in a ``tf.Print`` op that logs ``[x]`` when executed."""
    # The tf.Print op is created inside the branch function on purpose:
    # tf.cond only executes ops built inside the selected branch, so this
    # message appears only when this branch is taken.
    return tf.Print(x, [x])
def f2():
    """Return ``y`` wrapped in a ``tf.Print`` op that logs ``[y]`` when executed."""
    # Built inside the branch function so tf.cond runs it only when this
    # branch is selected (ops created outside the branch would always run).
    return tf.Print(y, [y])
# Note the cross-wiring: f2 is passed as the true branch and f1 as the false
# branch. With x = 1.0 and y = 2.0 the predicate x > y is False, so f1 runs
# and only x's value is printed.
op = tf.cond(x > y, true_fn=f2, false_fn=f1)
with tf.Session() as sess:
    # Running the cond op executes only the branch selected by the predicate.
    sess.run(op)
# If pred is True, true_fn is executed; otherwise false_fn is executed.
# Tensors created outside true_fn/false_fn are executed regardless of the
# chosen branch, which is why the tf.Print ops live inside f1/f2.
# Reference: https://stackoverflow.com/questions/47768298/how-to-understand-using-tf-cond-with-tf-print