# tf.where 用法示例
import tensorflow as tf
# Demo: use tf.where with only a condition to get the coordinates of the True
# elements of a boolean tensor.
# NOTE: TF1-style graph execution (tf.Session); under TF2 this needs the
# tf.compat.v1 API with eager execution disabled.

# The values 1..16 arranged as a rank-4 tensor of shape [4, 1, 2, 2]
# (equivalent to tf.range(0, 16) + 1).
temp = tf.reshape(tf.range(1, 17), [4, 1, 2, 2])

# Boolean mask, True where an element is greater than 6. Built once and reused
# so the graph does not contain duplicate comparison ops.
mask = tf.greater(temp, 6)

# With a single argument, tf.where returns one row per True element of `mask`,
# each row being that element's coordinate (one column per dimension of `mask`).
category_index = tf.where(mask)

with tf.Session() as sess:
    # Evaluate both tensors in one run; keep the NumPy result in its own name
    # instead of overwriting the `temp` tensor.
    mask_value, a = sess.run([mask, category_index])
    print(mask_value)
    print(a)
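解释：
当只传入 condition 参数时，tf.where 返回 condition 中所有值为 True 的元素的坐标（每行一个坐标）。结果的形状为 [True 元素个数, condition 的维数]，本例中为 [10, 4]。下面依次打印布尔掩码 tf.greater(temp, 6) 的值和 tf.where 返回的坐标：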
[[[[False False]
[False False]]]
[[[False False]
[ True True]]]
[[[ True True]
[ True True]]]
[[[ True True]
[ True True]]]]
[[1 0 1 0]
[1 0 1 1]
[2 0 0 0]
[2 0 0 1]
[2 0 1 0]
[2 0 1 1]
[3 0 0 0]
[3 0 0 1]
[3 0 1 0]
[3 0 1 1]]