tf.cond()是一个条件函数,根据条件返回的True或False 返回相应的结果
第一个参数是条件 bool 类型,第2个和第3个参数是返回的值,如果条件是True 返回第二个参数,如果条件是False 则返回第三个参数
import tensorflow as tf
a = tf.constant(2)
b = tf.constant(3)
x = tf.constant(4)
y = tf.constant(5)
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
with tf.Session() as session:
print(result.eval())