First, the official documentation:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
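To make the three formulas concrete, here is a minimal pure-Python sketch that applies them literally to 3-D nested lists. `gather3d` is a hypothetical helper written for illustration, not PyTorch's implementation. It also skips the shape checks torch.gather performs (input and index must have the same number of dimensions, and index may not be larger than input along any dimension other than dim). The input values 4*i + 2*j + k are chosen so that each output number tells you which element it was read from.

```python
# Hypothetical helper for illustration: a literal translation of the
# documented 3-D formulas, not PyTorch's implementation.
def gather3d(inp, dim, index):
    """Gather from a 3-D nested list along `dim`; output has index's shape."""
    if dim not in (0, 1, 2):
        raise ValueError("dim must be 0, 1 or 2 for a 3-D input")
    out = []
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, idx in enumerate(row):
                if dim == 0:
                    out_row.append(inp[idx][j][k])  # out[i][j][k] = input[index[i][j][k]][j][k]
                elif dim == 1:
                    out_row.append(inp[i][idx][k])  # out[i][j][k] = input[i][index[i][j][k]][k]
                else:
                    out_row.append(inp[i][j][idx])  # out[i][j][k] = input[i][j][index[i][j][k]]
            out_plane.append(out_row)
        out.append(out_plane)
    return out


# 2x2x2 input whose value 4*i + 2*j + k encodes its own position (i, j, k).
inp = [[[4 * i + 2 * j + k for k in range(2)] for j in range(2)] for i in range(2)]
index = [[[1, 0], [0, 1]]]  # shape 1x2x2

for dim in (0, 1, 2):
    print(dim, gather3d(inp, dim, index))
# 0 [[[4, 1], [2, 7]]]
# 1 [[[2, 1], [0, 3]]]
# 2 [[[1, 0], [2, 3]]]
```

Reading the dim=0 line back: the first output 4 is input[1][0][0], i.e. only the first coordinate was replaced by index[0][0][0] = 1, while j and k stayed where they were.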
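>>> import torch
>>> x = torch.tensor([[0.3607, -0.2859, -0.3938],
...                   [0.2429, -1.3833, -2.3134]])
>>> x
tensor([[ 0.3607, -0.2859, -0.3938],
        [ 0.2429, -1.3833, -2.3134]])
>>> index = torch.tensor([[0, 1, 1]])  # int64, same dtype as torch.LongTensor
>>> torch.gather(x, 0, index)
tensor([[ 0.3607, -1.3833, -2.3134]])
>>> torch.gather(x, 1, index)
tensor([[ 0.3607, -0.2859, -0.2859]])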
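Putting the documentation and the example together:
When dim=0, output[i][j] keeps the column index j and takes its row index from index[i][j], i.e. output[i][j] = input[index[i][j]][j]. In the example above, index[0] is (0, 1, 1), so columns 0, 1, 2 are read from rows 0, 1, 1, giving (0.3607, -1.3833, -2.3134).
When dim=1, output[i][j] keeps the row index i and takes its column index from index[i][j], i.e. output[i][j] = input[i][index[i][j]]. In the example above the row index is 0, so the output takes columns (0, 1, 1) of row 0: (0.3607, -0.2859, -0.2859).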
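The rule can be checked without PyTorch. The sketch below uses a hypothetical helper, `source_coords`, on plain Python lists: for every output position it records the (row, col) of input that gather reads, then looks those positions up in the x from the example.

```python
# Hypothetical helper for illustration (plain Python, no PyTorch needed).
def source_coords(index, dim):
    """For each output position (i, j), return the (row, col) of input that
    gather(input, dim, index) reads from, for a 2-D index."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    coords = []
    for i, idx_row in enumerate(index):
        if dim == 0:
            # column j is kept, the row comes from index[i][j]
            coords.append([(idx, j) for j, idx in enumerate(idx_row)])
        else:
            # row i is kept, the column comes from index[i][j]
            coords.append([(i, idx) for idx in idx_row])
    return coords


x = [[0.3607, -0.2859, -0.3938],
     [0.2429, -1.3833, -2.3134]]
index = [[0, 1, 1]]

for dim in (0, 1):
    coords = source_coords(index, dim)
    values = [[x[r][c] for r, c in row] for row in coords]
    print(dim, coords, values)
# 0 [[(0, 0), (1, 1), (1, 2)]] [[0.3607, -1.3833, -2.3134]]
# 1 [[(0, 0), (0, 1), (0, 1)]] [[0.3607, -0.2859, -0.2859]]
```

The values match the two torch.gather outputs in the example.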
>>> index = torch.tensor([[0, 1, 1], [1, 1, 1]])
>>> torch.gather(x, 1, index)
tensor([[ 0.3607, -0.2859, -0.2859],
        [-1.3833, -1.3833, -1.3833]])
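With dim=1 the row index stays fixed: row 0 of the output is read from row 0 of input, where columns (0, 1, 1) give (0.3607, -0.2859, -0.2859);
row 1 of the output is read from row 1 of input, where columns (1, 1, 1) give (-1.3833, -1.3833, -1.3833).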
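Another way to read "the row index stays fixed" is to build an explicit row-index grid with the same shape as index and pair it element-wise with index. In PyTorch this pairing is what advanced indexing does, so for a 2-D input torch.gather(x, 1, index) gives the same result as x[torch.arange(index.size(0)).unsqueeze(1), index]. The plain-Python sketch below mirrors that idea on the example data; the name `rows` is only for illustration.

```python
x = [[0.3607, -0.2859, -0.3938],
     [0.2429, -1.3833, -2.3134]]
index = [[0, 1, 1],
         [1, 1, 1]]

# Illustrative row-index grid for dim=1: rows[i][j] == i, same shape as index.
rows = [[i] * len(idx_row) for i, idx_row in enumerate(index)]

# Pair each fixed row with the column taken from index.
out = [[x[rows[i][j]][index[i][j]] for j in range(len(index[i]))]
       for i in range(len(index))]
print(out)  # [[0.3607, -0.2859, -0.2859], [-1.3833, -1.3833, -1.3833]]
```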