Code
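import numpy as np
a = np.arange(2*3*4)
a = a.reshape([2,3,4])               # values 0..23 arranged as a (2, 3, 4) array
b = np.zeros([2,3,1]).astype('int')  # integer indices along the last axis, all 0
c = np.take_along_axis(a,b,-1)       # c[i,j,k] = a[i,j,b[i,j,k]]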
Result
a.shape = (2, 3, 4)
b.shape = (2, 3, 1)
c.shape = (2, 3, 1)
a = array([[[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11]],

           [[12, 13, 14, 15],
            [16, 17, 18, 19],
            [20, 21, 22, 23]]])
b = array([[[0],
            [0],
            [0]],

           [[0],
            [0],
            [0]]])
c = array([[[ 0],
            [ 4],
            [ 8]],

           [[12],
            [16],
            [20]]])
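Because b is all zeros in the result above, every row simply returns its first element, which hides the per-row selection. The sketch below uses a hypothetical index array with a different column per row (not part of the original example) and compares np.take_along_axis against an explicit loop over the rule c[i, j, k] = a[i, j, b[i, j, k]]. It is meant as an illustration of the indexing rule, not as a description of NumPy's internal implementation.

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# Hypothetical index array (an assumption for illustration):
# pick a different position along the last axis in each row.
b = np.array([[[0], [1], [2]],
              [[3], [2], [1]]])

c = np.take_along_axis(a, b, axis=-1)

# Explicit-loop version of the same rule: c[i, j, k] = a[i, j, b[i, j, k]]
expected = np.empty_like(b)
for i in range(b.shape[0]):
    for j in range(b.shape[1]):
        for k in range(b.shape[2]):
            expected[i, j, k] = a[i, j, b[i, j, k]]

print(c.ravel())                    # the selected elements, flattened
print(np.array_equal(c, expected))  # True if both approaches agree
```

The output has the shape of b, not of a: one element of a is selected for each entry of b.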
Equivalent to (PyTorch, with a and b as tensors and b of dtype int64)
torch.gather(a,-1,b)
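To check the equivalence directly, the NumPy arrays have to be converted to tensors first, since torch.gather does not accept NumPy arrays and requires an int64 index tensor. The following is a minimal sketch assuming PyTorch is installed; the names c_np and c_torch are introduced here for illustration only.

```python
import numpy as np
import torch

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b = np.zeros((2, 3, 1), dtype=np.int64)  # int64 so it can serve as a gather index

c_np = np.take_along_axis(a, b, axis=-1)

# torch.gather(input, dim, index) works on tensors, so convert first.
c_torch = torch.gather(torch.from_numpy(a), -1, torch.from_numpy(b))

print(np.array_equal(c_np, c_torch.numpy()))  # True if the two results match
```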
Reference: https://stackoverflow.com/questions/37878946/indexing-one-array-by-another-in-numpy