TensorFlow的tf.where函数详解与例子

官方说明:
If both x and y are None, then this operation returns the coordinates of true elements of condition. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.

If both non-None, condition, x and y must be broadcastable to the same shape.

The condition tensor acts as a mask that chooses, based on the value at each element, whether the corresponding element / row in the output should be taken from x (if true) or y (if false).

官方文档很抽象,必须结合例子来理解。一共有两种用法,分别是带有xy参数和不带这两个参数的用法。

用法1
a1=np.array([[1,0,0],[0,1,1]]) 
a2=np.array([[3,2,3],[4,5,6]])
tf.where(tf.equal(a1,1),a1,a2)

输出的结果是

<tf.Tensor: id=13, shape=(2, 3), dtype=int64, numpy=
array([[1, 2, 3],
       [4, 1, 1]])>

也就是,当condition为真,也就是tf.equal(a1,1,即a1中的元素为1,返回的数组中所在位置元素来自a1,否则来自b1。输出的数组中,原数组a1不等于1的元素被替换成了对应位置b1中的元素。
再来一个例子,

tf.where(tf.equal(a1,1),a1,100+a1)

输出的结果是

<tf.Tensor: id=19, shape=(2, 3), dtype=int64, numpy=
array([[  1, 100, 100],
       [100,   1,   1]])>

数组a1中不等于1的元素,其值加上100。

用法2

不带xy参数的时候,返回满足condition的元素所在位置。需要关注的是返回值的形式。

tf.where(tf.equal(a,1)) 

输出结果

<tf.Tensor: id=55, shape=(3, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 1],
       [1, 2]])>

这是一个(3, 2)数组,行数表示满足条件的元素的数目a1中一共有3个元素为1,所有行数为3。每一列代表的是符合条件的元素的坐标,比如第一个元素[0,0],表示第一个满足条件的元素的index是(0,0)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值