TensorFlow2-高阶操作(七):where(采集元素)【从1个Tensor中采集元素】【从2个Tensor中采集元素:where(cond, A, B)】

一、where:采集元素

1、从1个Tensor中采集元素【indices=where(mask)、gather_nd(input, indices)】

在这里插入图片描述

1.1 使用boolean_mask采集元素

import tensorflow as tf

a = tf.convert_to_tensor(
    [[0.79134136, 0.09345922, -0.7822895],
     [1.9430199, -0.2962239, -1.1451387],
     [0.35126936, 1.0099757, 0.67769486]])
print("a = \n", a)
print("-" * 200)

mask = a > 0
print("mask = \n", mask)
print("-" * 100)

b = tf.boolean_mask(a, mask)
print("b = \n", b)
print("-" * 200)

打印结果:

a = 
 tf.Tensor(
[[ 0.79134136  0.09345922 -0.7822895 ]
 [ 1.9430199  -0.2962239  -1.1451387 ]
 [ 0.35126936  1.0099757   0.67769486]], shape=(3, 3), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
mask = 
 tf.Tensor(
[[ True  True False]
 [ True False False]
 [ True  True  True]], shape=(3, 3), dtype=bool)
----------------------------------------------------------------------------------------------------
b = 
 tf.Tensor([0.79134136 0.09345922 1.9430199  0.35126936 1.0099757  0.67769486], shape=(6,), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Process finished with exit code 0

1.2 使用where采集元素

import tensorflow as tf

a = tf.convert_to_tensor(
    [[0.79134136, 0.09345922, -0.7822895],
     [1.9430199, -0.2962239, -1.1451387],
     [0.35126936, 1.0099757, 0.67769486]])

print("a = \n", a)
print("-" * 200)

mask = a > 0
print("mask = a > 0 = \n", mask)
print("-" * 100)

indices = tf.where(mask)
print("indices = tf.where(mask) = \n", indices)
print("-" * 100)

d = tf.gather_nd(a, indices)
print("d = tf.gather_nd(a, indices) = \n", d)
print("-" * 200)

打印结果:

a = 
 tf.Tensor(
[[ 0.79134136  0.09345922 -0.7822895 ]
 [ 1.9430199  -0.2962239  -1.1451387 ]
 [ 0.35126936  1.0099757   0.67769486]], shape=(3, 3), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
mask = a > 0 = 
 tf.Tensor(
[[ True  True False]
 [ True False False]
 [ True  True  True]], shape=(3, 3), dtype=bool)
----------------------------------------------------------------------------------------------------
indices = tf.where(mask) = 
 tf.Tensor(
[[0 0]
 [0 1]
 [1 0]
 [2 0]
 [2 1]
 [2 2]], shape=(6, 2), dtype=int64)
----------------------------------------------------------------------------------------------------
d = tf.gather_nd(a, indices) = 
 tf.Tensor([0.79134136 0.09345922 1.9430199  0.35126936 1.0099757  0.67769486], shape=(6,), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Process finished with exit code 0

2、从2个Tensor中采集元素【where(cond, A, B)】

在这里插入图片描述

where(condition, x=None, y=None, name=None)

  • tf.where 通过boolean矩阵的 true or false 对候选条件下的两个矩阵进行element选取
  • 这里true就选x中的元素,false就选y中的元素
import tensorflow as tf

a = tf.where([[True, False], [False, True]], x=[[1, 2], [3, 4]], y=[[5, 6], [7, 8]])

print("a = \n", a)

打印结果:

a = 
 tf.Tensor(
[[1 6]
 [7 4]], shape=(2, 2), dtype=int32)



参考资料:
2020-06-05-tensorflow2-tf.where说明和例子
TensorFlow的tf.where函数详解与例子
TensorFlow函数:tf.where

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值