tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
作用:根据indeces
收集在params
上的值。如图:
例子:
import tensorflow as tf
temp = tf.range(0,10)*10 + tf.constant(1,shape=[10])
#收集下标1、5、9处的值
temp2 = tf.gather(temp,[1,5,9])
with tf.Session() as sess:
print(sess.run(temp))
print(sess.run(temp2))
输出
[ 1 11 21 31 41 51 61 71 81 91]
[11 51 91]
参考:https://blog.csdn.net/guotong1988/article/details/53172882