第一种形式:tf.where(tensor),numpy也可以,返回其中为true的元素的索引
y_true = np.array([0, 1, 0, 1, 0])
a = tf.where(y_true)
with tf.Session() as sess:
print(a.eval())
输出:
[[1]
[3]]
第二种形式:tf.where(tensor,a,b), 也可以为numpy, 将tensor中的true位置元素替换为a中对应位置元素,false的替换为b中对应位置元素。
y_true = np.array([[0, 1, 0, 1, 0], [0, 1, 0, 1, 0]])
a = tf.where(y_true)
with tf.Session() as sess:
print(a.eval())
输出:
[[0 1]
[0 3]
[1 1]
[1 3]]
公众号: