tf.gather(params, indices, validate_indices=None, name=None, axis=0)
Gather slices from `params` axis `axis` according to `indices`.
从'params'的'axis'维根据'indices'的参数值获取切片。就是在axis维根据indices取某些值。
参考博客https://blog.csdn.net/guotong1988/article/details/53172882,尝试实例
import tensorflow as tf
temp4=tf.reshape(tf.range(0,20)+tf.constant(1,shape=[20]),[2,2,5])
temp5=tf.gather(temp4,[0,1],axis=0) #indices是向量
temp6=tf.gather(temp4,1,axis=1) #indices是数值
temp7=tf.gather(temp4,[1,4],axis=2)
temp8=tf.gather(temp4,[[0,1],[3,4]],axis=2) #indices是多维的
with tf.Session() as sess:
print(sess.run(temp4))
print(sess.run(temp5))
print(sess.run(temp6))
print(sess.run(temp7))
print(sess.run(temp8))
输出:
temp4:
[[[ 1 2 3 4 5] [ 6 7 8 9 10]] [[11 12 13 14 15] [16 17 18 19 20]]]
temp5:(当indices是向量时,输出的形状和输入形状相同,不改变)
[[[ 1 2 3 4 5] [ 6 7 8 9 10]] [[11 12 13 14 15] [16 17 18 19 20]]]
temp6:(当indices是数值时,输出的形状比输入的形状少一维)
[[ 6 7 8 9 10] [16 17 18 19 20]]
temp7:
[[[ 2 5] [ 7 10]] [[12 15] [17 20]]]
temp8:(当indices是多维时,输出形状为“rank(params) + rank(indices) - 1”,即"params的维数+indices的维数-1")
[[[[ 1 2] [ 4 5]] [[ 6 7] [ 9 10]]] [[[11 12] [14 15]] [[16 17] [19 20]]]]