Advanced Operations
1 where
where(mask)
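Given only a boolean mask, tf.where returns the coordinates of every True element. The result is an int64 tensor of shape [number of True elements, rank of mask], with one row of indices per True element.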
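tf.random.normal returns different values on every run. The sketch below, offered as an illustration, pins the values printed in the session that follows with tf.constant so the output is reproducible. It also cross-checks the result against NumPy's np.argwhere, which is used here only for verification and is not part of the original notes.

```python
import numpy as np
import tensorflow as tf

# Fixed copy of the values printed in the session, so the result is reproducible
# (tf.random.normal would produce new numbers on every run).
a = tf.constant([[-0.9279645, 0.04238617, 0.1136281],
                 [-0.91321355, 0.96097076, -0.9072119],
                 [0.699401, -0.39297295, 0.73130745]])

mask = a > 0               # boolean tensor with the same shape as a
indices = tf.where(mask)   # int64 tensor, shape [number of True elements, rank of mask]

print(indices.shape)             # (5, 2)
print(indices.numpy().tolist())  # [[0, 1], [0, 2], [1, 1], [2, 0], [2, 2]]

# Both tf.where and np.argwhere list the coordinates in row-major order.
assert np.array_equal(indices.numpy(), np.argwhere(mask.numpy()))
```

The comparison with np.argwhere works because both functions list coordinates in row-major order.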
>>> import tensorflow as tf
>>> a = tf.random.normal([3, 3])
>>> a
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[-0.9279645 , 0.04238617, 0.1136281 ],
[-0.91321355, 0.96097076, -0.9072119 ],
[ 0.699401 , -0.39297295, 0.73130745]], dtype=float32)>
>>> mask = a > 0
>>> mask
<tf.Tensor: shape=(3, 3), dtype=bool, numpy=
array([[False, True, True],
[False, True, False],
[ True, False, True]])>
>>> indices = tf.where(mask)
>>> indices
<tf.Tensor: shape=(5, 2), dtype=int64, numpy=
array([[0, 1],
[0, 2],
[1, 1],
[2