I. where: Gathering Elements
1. Gathering elements from a single Tensor [indices = where(mask), gather_nd(input, indices)]
1.1 Gathering elements with boolean_mask
import tensorflow as tf
a = tf.convert_to_tensor(
[[0.79134136, 0.09345922, -0.7822895],
[1.9430199, -0.2962239, -1.1451387],
[0.35126936, 1.0099757, 0.67769486]])
print("a = \n", a)
print("-" * 200)
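# Boolean mask: True where the element is positive
mask = a > 0
print("mask = \n", mask)
print("-" * 100)
# boolean_mask keeps the elements where mask is True and returns them as a 1-D tensor, in row-major order
b = tf.boolean_mask(a, mask)
print("b = \n", b)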
print("-" * 200)
Output:
a =
tf.Tensor(
[[ 0.79134136 0.09345922 -0.7822895 ]
[ 1.9430199 -0.2962239 -1.1451387 ]
[ 0.35126936 1.0099757 0.67769486]], shape=(3, 3), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
mask =
tf.Tensor(
[[ True True False]
[ True False False]
[ True True True]], shape=(3, 3), dtype=bool)
----------------------------------------------------------------------------------------------------
b =
tf.Tensor([0.79134136 0.09345922 1.9430199 0.35126936 1.0099757 0.67769486], shape=(6,), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
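To make the semantics concrete, here is a dependency-free sketch of what boolean_mask computes when the input is 2-D and the mask has the same shape: it scans the tensor in row-major order (rows top to bottom, columns left to right) and keeps every element whose mask entry is True. The helper `boolean_mask_2d` is a hypothetical illustration on plain Python lists, not TensorFlow's implementation.

```python
def boolean_mask_2d(values, mask):
    """Hypothetical plain-Python analogue of tf.boolean_mask for a 2-D input
    and a mask of the same shape: keep the elements where mask is True,
    visiting them in row-major order."""
    if len(values) != len(mask) or any(len(r) != len(m) for r, m in zip(values, mask)):
        raise ValueError("values and mask must have the same shape")
    return [v for row, mrow in zip(values, mask) for v, keep in zip(row, mrow) if keep]


a = [[0.79134136, 0.09345922, -0.7822895],
     [1.9430199, -0.2962239, -1.1451387],
     [0.35126936, 1.0099757, 0.67769486]]
mask = [[x > 0 for x in row] for row in a]  # same condition as mask = a > 0
b = boolean_mask_2d(a, mask)
print(b)  # [0.79134136, 0.09345922, 1.9430199, 0.35126936, 1.0099757, 0.67769486]
```

The result has the same six values, in the same order, as the TensorFlow output above; the masked-out positions simply disappear, which is why the result is 1-D.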
1.2 Gathering elements with where
import tensorflow as tf
a = tf.convert_to_tensor(
[[0.79134136, 0.09345922, -0.7822895],
[1.9430199, -0.2962239, -1.1451387],
[0.35126936, 1.0099757, 0.67769486]])
print("a = \n", a)
print("-" * 200)
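# Boolean mask: True where the element is positive
mask = a > 0
print("mask = a > 0 = \n", mask)
print("-" * 100)
# Given only a condition, tf.where returns the coordinates of the True elements, shape (num_true, rank)
indices = tf.where(mask)
print("indices = tf.where(mask) = \n", indices)
print("-" * 100)
# gather_nd reads the element of a at each [row, col] coordinate
d = tf.gather_nd(a, indices)
print("d = tf.gather_nd(a, indices) = \n", d)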
print("-" * 200)
Output:
a =
tf.Tensor(
[[ 0.79134136 0.09345922 -0.7822895 ]
[ 1.9430199 -0.2962239 -1.1451387 ]
[ 0.35126936 1.0099757 0.67769486]], shape=(3, 3), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
mask = a > 0 =
tf.Tensor(
[[ True True False]
[ True False False]
[ True True True]], shape=(3, 3), dtype=bool)
----------------------------------------------------------------------------------------------------
indices = tf.where(mask) =
tf.Tensor(
[[0 0]
[0 1]
[1 0]
[2 0]
[2 1]
[2 2]], shape=(6, 2), dtype=int64)
----------------------------------------------------------------------------------------------------
d = tf.gather_nd(a, indices) =
tf.Tensor([0.79134136 0.09345922 1.9430199 0.35126936 1.0099757 0.67769486], shape=(6,), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
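The two-step version can be sketched the same way in plain Python, assuming a 2-D input: first collect the [row, col] coordinates of the True entries in row-major order (what tf.where(mask) returns), then look up the value at each coordinate (what tf.gather_nd does when every index is a full coordinate). The helpers `where_indices` and `gather_nd_2d` are hypothetical illustrations, not TensorFlow APIs.

```python
def where_indices(mask):
    """Hypothetical plain-Python analogue of tf.where(mask) for a 2-D mask:
    the [row, col] coordinates of every True entry, in row-major order."""
    return [[i, j] for i, row in enumerate(mask) for j, keep in enumerate(row) if keep]


def gather_nd_2d(values, indices):
    """Hypothetical analogue of tf.gather_nd(values, indices) for the case
    where each index is a full [row, col] coordinate: one scalar per index."""
    return [values[i][j] for i, j in indices]


a = [[0.79134136, 0.09345922, -0.7822895],
     [1.9430199, -0.2962239, -1.1451387],
     [0.35126936, 1.0099757, 0.67769486]]
mask = [[x > 0 for x in row] for row in a]  # same condition as mask = a > 0
indices = where_indices(mask)
print(indices)  # [[0, 0], [0, 1], [1, 0], [2, 0], [2, 1], [2, 2]]
d = gather_nd_2d(a, indices)
print(d)  # [0.79134136, 0.09345922, 1.9430199, 0.35126936, 1.0099757, 0.67769486]
```

Because the coordinates are produced in the same row-major order that boolean_mask uses, the final result equals the one from section 1.1. The two-step form is useful when the coordinates themselves are needed, for example to gather from another tensor of the same shape.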
2. Selecting elements from two Tensors [where(cond, A, B)]
where(condition, x=None, y=None, name=None)
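- tf.where uses the True/False values of a boolean condition tensor to choose, element by element, between two candidate tensors x and y.
- Where the condition is True, the element is taken from x; where it is False, it is taken from y.
- x and y must be passed together; if both are omitted, tf.where(condition) instead returns the coordinates of the True elements, as in section 1.2.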
import tensorflow as tf
# Take x where the condition is True, y where it is False
a = tf.where([[True, False], [False, True]], x=[[1, 2], [3, 4]], y=[[5, 6], [7, 8]])
print("a = \n", a)
Output:
a =
tf.Tensor(
[[1 6]
[7 4]], shape=(2, 2), dtype=int32)
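The element-wise selection can be sketched in plain Python as well, under the simplifying assumption that condition, x and y all have the same 2-D shape. The helper `where_select_2d` is hypothetical and, unlike TensorFlow 2's tf.where, does not broadcast its three inputs against each other.

```python
def where_select_2d(condition, x, y):
    """Hypothetical plain-Python analogue of tf.where(condition, x, y) when all
    three inputs share the same 2-D shape (no broadcasting): take the element
    from x where the condition is True and from y where it is False."""
    return [[xv if c else yv for c, xv, yv in zip(crow, xrow, yrow)]
            for crow, xrow, yrow in zip(condition, x, y)]


a = where_select_2d([[True, False], [False, True]],
                    x=[[1, 2], [3, 4]],
                    y=[[5, 6], [7, 8]])
print(a)  # [[1, 6], [7, 4]]
```

The diagonal entries (True) come from x and the off-diagonal entries (False) come from y, matching the TensorFlow output above.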