在使用tf.case()函数时遇到这个错误
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
sess = tf.InteractiveSession()
x = tf.random_uniform(shape=[], minval=-1, maxval=1)
y = tf.random_uniform(shape=[], minval=-1, maxval=1)
out = tf.case([(tf.less(x, y), lambda: x+y), (tf.greater(x, y), lambda: x-y)], default=0, exclusive=True)
print(sess.run(out))
因为里面只有default=0用到了int,确实不太符合tensorflow习惯,所以用tf.constant代替,然后遇到新的错误:
TypeError: true_fn and false_fn arguments to tf.cond must have the same number, type, and overall structure of return values.
true_fn output: Tensor("add:0", shape=(), dtype=float32)
false_fn output: Tensor("Const:0", shape=(), dtype=int32)
Error details:
Tensor("add:0", shape=(), dtype=float32) and Tensor("Const:0", shape=(), dtype=int32) have different types
这个比较好理解,我们的0是int类型,替换为0.0,问题解决。
此外我尝试了不同的tf.case格式(字典-v1与列表-v2)都可以运行。