a=tf.constant([[[1,2,3],[4,5,6]],
[[7,8,9],[10,11,12]]])
# 选择索引0和索引2的tensor
indices=tf.constant([0,2])
#tensor为a,维度为2,索引为0和2
b=tf.gather(a,indices,**axis=2**) # 注意axis=2,axis的选项必须指定
with tf.Session() as sess:
# result2 = sess.run(b)
print(sess.run(b))
输出结果为
tensor([[[ 1, 3],
[ 4, 6]],
[[ 7, 9],
[10, 12]]])
等同于torch的如下处理方法
import torch
#shape为(2,2,3)
a=torch.tensor([[[1,2,3],[4,5,6]],
[[7,8,9],[10,11,12]]])
#选择索引0和索引2的tensor
indices=torch.tensor([0,2])
#tensor为a,维度为2,索引为0和2
b=torch.index_select(a,2,indices)