对于Pytorch中的gather方法的直白理解,包括局限性以及代替方案

引言

陈述

torch.gather(input, dim, index, *) → tensor是一个从给定的input中按照给定的index沿着给定的dim逐一取出元素(element-wise)的操作。这意味着当我们的input涉及到多个维度的时候,创建index的开销和操作的开销也随之增大。这个时候,利用tensor的高级索引功能才是更适合的选择。

一句话总结

gather只适合collect部分element的情况,而不适合批次操作。

问题

gather的直白理解

import torch

data = torch.tensor([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9]])

indices = torch.tensor([[0, 2, 1],
                        [1, 0, 2],
                        [2, 1, 0]])

result = torch.gather(data, 1, indices)

'''
result:
tensor([[1, 3, 2],
        [5, 4, 6],
        [9, 8, 7]])
'''

在这个例子中,input是一个3×3的矩阵,同时index也是一个3×3的矩阵,然后dim为1。

注意这个dim,这是gather最容易混淆的点!

dim为1,我们可以说:“gather是按行索引”,同时我们也可以说:“gather是沿着列索引的”。(后者更标准)

初学者常常因为第一个表达而疑惑,但我们将步骤拆分开:

  • 第一行按照索引 [0, 2, 1],得到 [1, 3, 2]
  • 第二行按照索引 [1, 0, 2],得到 [5, 4, 6]
  • 第三行按照索引 [2, 1, 0],得到 [9, 8, 7]

用大白话来说,index中一行的元素,都逐一指向了data当中对应行的各列元素,比如说index中第一行第二列的2就指向了data的第一行的第三个元素。

所以“按行索引”和“沿着列索引”的表达是等价的,因为gather是element-wise的操作。

局限性

此时,考虑一个情况,我们给定一个二维tensor的形状为(n, m),即n个句子,然后每个句子有m个单词,然后考虑其被embedding至三维,其形状为(n, m, d),即每个单词都被embedding至形状为[d]的tensor。

这个时候,如果我们使用gather去抽取一部分单词,这就不合理了。

原因是因为我们需要将tokens给扩展到同样的维度去,比如说indices的形状为(n, k),k <= m,然后我们还需要扩展其为(n, k, d),这一步开销很大,然后因为是element-wise的,所以计算复杂度也变成了n×k×d。

替代方案

高级索引会更合理,我们可以忽略掉第三维,单纯对前两维进行索引。

我们需要一个创建一个表示行坐标的tensor,即:

batch_indices = torch.arange(n).unsqueeze(1).expand(-1, m)
result = embed_data[batch_indices, indices]

这里,两个索引的形状都是(n, m),也就是说,他们会逐元素组合以对应出每一个元素(不考虑第三维)的位置,也就是按照dim=1进行索引了。为了方便理解,可以把indices看作成如下形态:

  • 第一行索引 [(0,0), (0,2), (0,1)]
  • 第二行索引 [(1, 1), (1, 0), (1, 2)]
  • 第三行索引 [(2, 2), (2, 1), (2, 0)]

此时,便能节约大量的开销。

不过这里还存在一个与底层实现有关的疑惑点,gather在同一维度下的操作速度还是会慢于高级索引操作,哪怕业务层逻辑上两者有相同的计算复杂度。

这也导致,在如上我们规定的情况中,节约的时间成本其实远远大于d倍。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值