引子
gather,其英文意思是“聚集,收集”,大概意思应该为收集相关元素。Pytorch(1.12 )文档解释得很简单:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Gathers values along an axis specified by dim.
还给了个例子:
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
看解释的时候好像一下就懂了,看例子的时候一下就蒙B了。然后就查了些资料, 这里写下自己的理解。
理解
我遇到的主要问题是为什么index的维度数与input的维度数是一样的?
首先,index中的元素是dim维度上的索引,而提取元素的操作需要该元素每个维度上的值,仅仅根据一个dim参数上的值肯定是不行的,因此,就需要额外的信息反映出除了dim之外其他维度值,所以就出现了“index的维度数与input的维度数是一样的”。
如何从这一点获取其他维度的信息呢?
例子:2D tensor
k = torch.arange(0, 6).view((2, 3))
print(k)
tensor([[0, 1, 2],
[3, 4, 5]])
index = torch.randint(3, (2,3), dtype=torch.long)
k1 = torch.gather(k, 1, index)
print(index)
print(k1)
tensor([[2, 1, 1],
[0, 0, 2]])
tensor([[2, 1, 1],
[3, 3, 5]])
从index角度来解释,以index第一个元素index[0,0]即2为例。要收集的元素第1维上的值肯定是2(index[0,0]),那么第0维怎么确定呢? 答案就是2(index[0,0])这个值所在位置的第0维,这就是为什么index的维度数要和input维度数相同的原因,即待收集元素的其余维度上的值用index元素(“2”)位置信息的其余维度的值表示。因此,index其余每个维度的取值大小可以小于或等于input。
示意图如下:
这么一想,是不是Pytorch(1.12 )文档列举的例子就明白了呢:
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2