记录一下torch.gather
函数
用法:torch.gather(input: Tensor, dim: int, index: LongTensor, *, sparse_grad=False, out=None) -> Tensor
功能:指定张量index
,根据其元素的值来获取输入矩阵input
上的值。
注意
index
需要与input
有相同的维度,并且对d!=dim
时要求index.size(d)<=input.size(d)
意思就是说如果input的size为(2, 3, 4),如果dim指定为1,那么需要index.size(0)<=2以及index.size(2)<=4.- 函数输出的Tensor与index的shape相同
举个例子:
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1],
[ 4, 3]])
怎么得到这个结果的呢,可以这样记忆:index
现在是[[0, 0], [1, 0]]
,它的每个元素在index
中都有其索引,比如元素1
索引是[1, 0](index[1, 0]=1)
,由于现在指定的dim=1
,那么就用1
代替[1, 0]
中dim=1
处的0
,变成[1, 1]
,即获取到input[1, 1]
,如下图所示。
对于多维矩阵也是一样的流程,用index的每个元素的值代替该元素在index上的索引在dim维度上的值,便能得到在input上的索引。
也就是官方举的例子:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
再有一个例子
>>> input_ = [[2, 3, 4, 5, 0, 0],
[1, 4, 3, 0, 0, 0],
[4, 2, 2, 5, 7, 0],
[1, 0, 0, 0, 0, 0]]
>>> input_ = torch.tensor(input_)
>>> index = torch.LongTensor([[3],[2],[4],[0]])
>>> torch.gather(input_, 1, index)
tensor([[5],
[3],
[7],
[1]])