引言
陈述
torch.gather(input, dim, index, *) → tensor是一个从给定的input中按照给定的index沿着给定的dim逐一取出元素(element-wise)的操作。这意味着当我们的input涉及到多个维度的时候,创建index的开销和操作的开销也随之增大。这个时候,利用tensor的高级索引功能才是更适合的选择。
一句话总结
gather只适合collect部分element的情况,而不适合批次操作。
问题
gather的直白理解
import torch
data = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
indices = torch.tensor([[0, 2, 1],
[1, 0, 2],
[2, 1, 0]])
result = torch.gather(data, 1, indices)
'''
result:
tensor([[1, 3, 2],
[5, 4, 6],
[9, 8, 7]])
'''
在这个例子中,input是一个3×3的矩阵,同时index也是一个3×3的矩阵,然后dim为1。
注意这个dim,这是gather最容易混淆的点!
dim为1,我们可以说:“gather是按行索引”,同时我们也可以说:“gather是沿着列索引的”。(后者更标准)
初学者常常因为第一个表达而疑惑,但我们将步骤拆分开:
- 第一行按照索引
[0, 2, 1]
,得到[1, 3, 2]
- 第二行按照索引
[1, 0, 2]
,得到[5, 4, 6]
- 第三行按照索引
[2, 1, 0]
,得到[9, 8, 7]
用大白话来说,index中一行的元素,都逐一指向了data当中对应行的各列元素,比如说index中第一行第二列的2就指向了data的第一行的第三个元素。
所以“按行索引”和“沿着列索引”的表达是等价的,因为gather是element-wise的操作。
局限性
此时,考虑一个情况,我们给定一个二维tensor的形状为(n, m),即n个句子,然后每个句子有m个单词,然后考虑其被embedding至三维,其形状为(n, m, d),即每个单词都被embedding至形状为[d]的tensor。
这个时候,如果我们使用gather去抽取一部分单词,这就不合理了。
原因是因为我们需要将tokens给扩展到同样的维度去,比如说indices的形状为(n, k),k <= m,然后我们还需要扩展其为(n, k, d),这一步开销很大,然后因为是element-wise的,所以计算复杂度也变成了n×k×d。
替代方案
高级索引会更合理,我们可以忽略掉第三维,单纯对前两维进行索引。
我们需要一个创建一个表示行坐标的tensor,即:
batch_indices = torch.arange(n).unsqueeze(1).expand(-1, m)
result = embed_data[batch_indices, indices]
这里,两个索引的形状都是(n, m),也就是说,他们会逐元素组合以对应出每一个元素(不考虑第三维)的位置,也就是按照dim=1进行索引了。为了方便理解,可以把indices看作成如下形态:
- 第一行索引
[(0,0), (0,2), (0,1)]
- 第二行索引
[(1, 1), (1, 0), (1, 2)]
- 第三行索引
[(2, 2), (2, 1), (2, 0)]
此时,便能节约大量的开销。
不过这里还存在一个与底层实现有关的疑惑点,gather在同一维度下的操作速度还是会慢于高级索引操作,哪怕业务层逻辑上两者有相同的计算复杂度。
这也导致,在如上我们规定的情况中,节约的时间成本其实远远大于d倍。