Let's start with the explanation from the official documentation:
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
If input is an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1}) and dim = i, then index must be an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1}) where y ≥ 1, and out will have the same size as index.
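To make the rule concrete, here is a minimal sketch that checks the dim == 1 formula against `torch.gather` with explicit loops. The tensor values and the names `inp`, `index`, and `expected` are made up for illustration; the index shape (2, 2, 4) matches the input shape (2, 3, 4) everywhere except along dim 1, where y = 2.

```python
import torch

# Hypothetical example data: a 3-D input of shape (2, 3, 4).
inp = torch.arange(24).reshape(2, 3, 4)

# index matches inp in every dimension except dim=1, where y = 2.
# Every entry must be a valid position along dim=1, i.e. in [0, 3).
index = torch.tensor([[[0, 1, 2, 0], [2, 2, 1, 1]],
                      [[1, 0, 0, 2], [0, 1, 2, 2]]])  # shape (2, 2, 4)

out = torch.gather(inp, dim=1, index=index)

# Rebuild the result with explicit loops, following the dim == 1 rule:
# out[i][j][k] = input[i][index[i][j][k]][k]
expected = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            expected[i][j][k] = inp[i][index[i][j][k]][k]

print(out.shape)                    # has the same size as index
print(torch.equal(out, expected))   # compares gather with the loop version
```

The loops are only meant to mirror the documented formula; in practice `torch.gather` does the same lookup in a single vectorized call.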
Here is an example:
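import torch
# 2x3 float tensor to gather from
b = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
# int64 indices; with dim=1 each entry selects a column within its own row
index_1 = torch.tensor([[0, 1], [2, 0]])
print(torch.gather(b, dim=1, index=index_1))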
# Output
tensor([[1., 2.],
        [6., 4.]])
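Now let's work out the output by following the documentation. For a 2-D tensor with dim = 1, the rule reduces to out[i][j] = input[i][index[i][j]]:
out[0][0] = input[0][index[0][0]] = input[0][0] = 1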
out[0][1] = input[0][index[0][1]] = input[0][1] = 2
out[1][0] = input[1][index[1][0]] = input[1][2] = 6
out[1][1] = input[1][index[1][1]] = input[1][0] = 4
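The hand calculation above can also be written as a small loop. The sketch below is an illustrative re-implementation for 2-D tensors only; `gather_2d` and `index_0` are hypothetical names introduced here, not part of PyTorch. It also applies the dim == 0 rule, out[i][j] = input[index[i][j]][j], to the same tensor `b`.

```python
import torch


def gather_2d(inp, dim, index):
    """Loop-based sketch of torch.gather for 2-D tensors (illustrative helper)."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            if dim == 0:
                # index selects the row; the column stays j
                out[i][j] = inp[index[i][j]][j]
            else:
                # index selects the column; the row stays i
                out[i][j] = inp[i][index[i][j]]
    return out


b = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

# Same index as in the example: gathers along dim=1.
index_1 = torch.tensor([[0, 1], [2, 0]])
manual_1 = gather_2d(b, 1, index_1)
print(torch.equal(manual_1, torch.gather(b, dim=1, index=index_1)))

# Hypothetical index for dim=0: each entry is a row index (0 or 1).
index_0 = torch.tensor([[0, 1, 0], [1, 0, 1]])
manual_0 = gather_2d(b, 0, index_0)
print(torch.equal(manual_0, torch.gather(b, dim=0, index=index_0)))
```

With dim=0 the roles swap: the index value replaces the row position while the column position is kept, so `index_0` needs as many columns as it reads from `b`.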