tf.gather
根据索引从参数轴上收集切片,索引必须是任何维度的整数张量 (通常为 0-D 或 1-D)
import tensorflow as tf
t1 = tf.reshape(tf.range(0,16),[2,2,4])
# [[[ 0 1 2 3]
# [ 4 5 6 7]]
#
# [[ 8 9 10 11]
# [12 13 14 15]]]
# 取第2维中,index为1,3的数据
t2 = tf.gather(t1,[1,3],axis=2)
with tf.Session() as sess:
print(sess.run(t1))
print(sess.run(t2))
[[[ 1 3]
[ 5 7]]
[[ 9 11]
[13 15]]]