torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
input和index都是tensor,值得注意的是dim这个参数,意味着在对应维度位置上取index的值,可参考对dim参数的说明,即dim是会变为1的那个维度,或者说是应用index的维度。
具体说明可参考该博客
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
下面看一个二维的例子
b = torch.Tensor([[1, 2, 3], [4, 5, 6]])
print(b)
index_1 = torch.LongTensor([[0], [2]])
index_2 = torch.LongTensor([[1, 1,1]])
print( torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))
结果为,可以看到第一个gather的dim=1的维度消失了,由[2,3]变为[2,1],第二个gather的dim=0的维度消失了,由[2,3]变为[1,3]
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[1.],
[6.]])
tensor([[4., 5., 6.]])