参考链接传送门
tf.gather():
tf.gather() 单一维度方向的数据,进行任意顺序的切片。
-
源张量是一维数据
- 代码:
输出:import tensorflow as tf print("源张量") a = tf.constant([0,1,2,3,4,5,6,7,8,9]) print(a) b = tf.gather(a, indices=[0,2,4,6,8,1,3,5,7,9]) print(b) b = tf.gather(a, indices=[0,2,4]) print(b)源张量 tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32) tf.Tensor([0 2 4 6 8 1 3 5 7 9], shape=(10,), dtype=int32) tf.Tensor([0 2 4], shape=(3,), dtype=int32)
- 代码:
-
源张量是二维数据
- 代码
输出:import tensorflow as tf print("原张量") a = tf.constant([[0,1,2,3,4],[5,6,7,8,9], [10,11,12,13,14]]) print(a) print("\n过滤之后张量(axis=0):") b = tf.gather(a, axis=0, indices=[0, 2]) #只能过滤一个维度方向 print(b) print("\n过滤之后张量(axis=1):") b = tf.gather(a, axis=1, indices=[1,3]) #只能过滤一个维度方向 print(b)<tf.Tensor: shape=(3, 5), dtype=int32, numpy= array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14]])> 过滤之后张量(axis=0): <tf.Tensor: shape=(2, 5), dtype=int32, numpy= array([[ 0, 1, 2, 3, 4], [10, 11, 12, 13, 14]])> 过滤之后张量(axis=1): <tf.Tensor: shape=(3, 2), dtype=int32, numpy= array([[ 1, 3], [ 6, 8], [11, 13]])>
- 代码
本文详细介绍了TensorFlow中的tf.gather()函数,展示了如何使用该函数从一维和二维张量中按指定索引进行数据切片。通过示例代码解释了该函数在不同维度过滤数据的操作,并给出了相应的输出结果。

被折叠的 条评论
为什么被折叠?



