下面我们直接看使用方法和功能介绍:
import tensorflow as tf
import numpy as np
"""
1. 该函数的参数讲解:
tf.gather(
params, 传入的tensor
indices, 指定的索引
validate_indices=None, 不重要
name=None, 命名
axis=0 指定轴
2. 功能:
就是抽取出params的第axis维度上在indices里面所有的index
需要注意的是indices里面最大值需要小等于params在指定的axis下ndim的长度。
"""
c1 = tf.constant(np.arange(7), shape=[7])
g1 = tf.gather(c1, indices=[1, 3]) # 获取索引为1和3的值
c2 = tf.constant([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4],
[5, 5, 5]], shape=[5, 3])
g2 = tf.gather(c2, indices=[1, 4], axis=0) # 获取在第1维度上的,索引为1和4的值
with tf.Session() as sess:
c1_, g1_, c2_, g2_ = sess.run(fetches=[c1, g1, c2, g2])
print(c1_)
print('-' * 100)
print(g1_)
print('*' * 100)
print('*' * 100)
print(c2_)
print('-' * 100)
print(g2_)
结果查看:
[0 1 2 3 4 5 6]
-------------------------------------------------------------------------------------------
[1 3]
*******************************************************************************************
*******************************************************************************************
[[1 1 1]
[2 2 2]
[3 3 3]
[4 4 4]
[5 5 5]]
-------------------------------------------------------------------------------------------
[[2 2 2]
[5 5 5]]