tf.gather
tf.gather(params, indices, validate_indices=None, name=None, axis=0)
params 表示你输入的张量,indices表示你想要params张量中切片的维度,所以这个函数就是挑选出params中indices对应的数。
举例子
x = tf.constant(np.arange(8).reshape((2,2,2)))
y = tf.gather(x,[0])
sess = tf.Session()
print(sess.run(y))
print('---------')
print(sess.run(x))
[[[0 1]
[2 3]]]
---------
[[[0 1]
[2 3]]
[[4 5]
[6 7]]]
就相当把x矩阵的第一个切片取出来了
想要其他轴切片,可以设置axis这个参数