tf.gather:
函数原型:
tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
说明:
- 根据indices中的索引从params中取出响应的元素来替换掉索引形成张量返回。
例子:
import tensorflow as tf
a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
index_a = tf.Variable([0, 2])
index_a1=tf.Variable([[0], [2]])
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather(a, index_a)))
print(sess.run(tf.gather(b, index_b)))
print(sess.run(tf.gather(a, index_a1)))
输出:
[[ 1 2 3 4 5]
[11 12 13 14 15]]
================================
[3 5 7 9]
================================
[[[ 1 2 3 4 5]]
[[11 12 13 14 15]]]