tf.where(input, name=None)
返回一个布尔张量的真值的位置
此操作返回输入中真实元素的坐标。坐标在二维张量中返回,其中第一个维度(行)表示真实元素的数量,第二个维度(列)表示真实元素的坐标。记住,输出张量的形状可以根据输入中有多少个真值而变化。索引以行为主的顺序输出。
例子:
# 'input' tensor is [[True, False]
# [True, False]]
# 'input' has two true values, so output has two coordinates.
# 'input' has rank of 2, so coordinates have two indices.
where(input) ==> [[0, 0],
[1, 0]]
***注释:输出shape=[2,2],第一个维度是2,表示有两个真值,第二个维度表示真值在inpu中对应的坐标的维度。
[0,0]表示第一个True在input中的对应坐标是第0行,第0列。
[1,0]表示第二个True在input中对应的坐标是第1行,第0列(也就是坐标索引值)***
# `input` tensor is [[[True, False]
# [True, False]]
# [[False, True]
# [False, True]]
# [[False, False]
# [False, True]]]
# 'input' has 5 true values, so output has 5 coordinates.
# 'input' has rank of 3, so coordinates have three indices.
where(input) ==> [[0, 0, 0],
[0, 1, 0],
[1, 0, 1],
[1, 1, 1],
[2, 1, 1]]
**注释:输出shape=[5,3]。input有5个True,所以输出有5行,shape(input)=[3,2,2],
所以输出每一行都有三列,代表每一个True在input中的对应坐标。
以第一行[0,0,0]为例解释一下怎么来的,第一个True在input的第一维中站在第一个,所以第一个是0,在第二维中站在第一个,
所以第二个还是0,在第三维中还是站在第一个,所以第三个还是0**
实际上,tf.where还有一个用法,tf.where(input, a,b)
,其中a,b均为尺寸一致的tensor,实现a中对应input中true的位置的元素值不变,其余元素由b中对应位置元素替换。
import tensorflow as tf
import numpy as np
sess=tf.InteractiveSession( )
a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])
print(sess.run(tf.equal(a,1)),'\n')
print(sess.run(tf.equal(a,0)),'\n')
print(sess.run(tf.where(tf.equal(a,1),a,a1)),'\n')
print(sess.run(tf.where(tf.equal(a,0),a,1-a1)),'\n')
输出结果:
[[ True False False]
[False True True]]
[[False True True]
[ True False False]]
[[1 2 3]
[4 1 1]]
[[-2 0 0]
[ 0 -4 -5]]