【Pytorch】index_select和gather函数的对比

在Pytorch中,index_selectgather均是被用于张量选取的常用函数,本文通过实例来对比这两个函数。

1. index_select

沿着张量的某个dim方向,按照index规定的选取指定的低一维度张量元素整体,在拼接成一个张量。其官方解释如下:

torch.index_select(input, dim, index, out=None) 
"""
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor
"""

先简单看两个示例:
示例1:沿着dim=0的方向进行

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.index_select(a, dim=0, index=torch.tensor([0,1,0,1]))
# b为tensor([[1, 2, 3],
#        [4, 5, 6],
#        [1, 2, 3],
#        [4, 5, 6]])

在这里插入图片描述
显然,对于二维张量,dim=0意味着按照index的编号选取指定的行,拼接成目标张量。其返回值仍保持和原始张量相同的ndim

示例2:沿着dim=1的方向进行

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.index_select(a, dim=1, index=torch.tensor([1,1]))
# b为tensor([[2, 2],
#          [5, 5]])

在这里插入图片描述
对于二维张量,dim=1意味着按照index的编号选取指定的列,拼接成目标张量。其返回值仍保持和原始张量相同的ndim

根据上述两个例子,可见index_select的作用间接明了,即选取某个dim上的若干个元素,将其拼接为目标张量。其中index为一个一维张量,表明该dim上做选取的具体元素,返回张量与原张量的ndim一致。

2. gather

相较于index_selectgather就显得让人难以理解的多。个人理解,其操作相当于用于沿着张量的某个dim方向,按照index规定的选取指定元素,构成该为维度上的每个子张量,最后拼接成一个张量。其官方解释如下:

torch.gather(input, dim, index, out=None, sparse_grad=False)

"""
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
"""

是不是还是令人费解?我们先以两个2维张量的例子来说明:

示例1:沿着dim=1的方向进行选择

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.gather(input=a, dim=1, index=torch.tensor([[2,0,2,1], [1,1,0,0]]))
# 返回值为 tensor([[3, 1, 3, 2],
#         [5, 5, 4, 4]])

其操作过程可参照下图:
在这里插入图片描述
由上图可见,dim=1表示在二维张量中,以行为单位,对每行中的元素,按照index的索引号进行选取,再拼接到一起。从张量shape上看,其在dim=0上保持一致,对dim=1进行了放大或缩小。

对于更一般的张量,gather的过程可理解为沿着dim维的size,对各个子张量进行选取和重新的拼接,因此其返回值和原始张量的ndim是相同的。

示例2:沿着dim=0的方向进行选择

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.gather(input=a, dim=0, index=torch.tensor([[0, 1, 0], [1,0,1], [0, 0, 0],[1,1,1]]))
# 返回值为 tensor([[1, 5, 3],
#        [4, 2, 6],
#        [1, 2, 3],
#        [4, 5, 6]])

其操作过程可参照下图:
在这里插入图片描述
对于二维张量,其操作过程与dim=1相反,即以行为单位,对每列中的元素,按照index的索引号进行选取,再拼接到一起。从张量shape上看,其在dim=1上保持一致,对dim=0进行了放大或缩小。

示例3:三维张量的例子

a = torch.arange(24).reshape(2,3,4)
# a为tensor([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]],

#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

b = torch.gather(a, dim=2, index=torch.tensor([[[2], [1], [0]], [[1], [2], [3]]]))

# b为tensor([[[ 2],
#            [5],
#            [8]],

#        [[13],
#        [18],
#        [23]]])

简单解释下,其选取的dim=2,即沿着三维张量最内层的张量进行元素选取和拼接,其只选取了一次。因此,上述操作可理解为每个最内层选取一个元素。

3. 总结

index_selectgather虽然都可用于张量元素的选取和重塑,主要参数的命名也类似,但其功能截然不同。简要而言:
(1)index_select 用于对dim方向各子张量的整体选取和拼接,其中的index为一维张量;
(2)gather用于对dim方向各子张量的元素在其它维度方向上的选取和拼接,其中的index为与原张量同ndim的张量。

### 回答1: 我可以提供一些建议,但无法保证它可以在没有改变原模型的情况下替换 torch.gather 操作。首先,可以尝试使用 torch.index_select 来实现同样的功能。其次,可以考虑使用 torch.masked_select 来替换 torch.gather,具体可以参考 PyTorch 的文档。 ### 回答2: 在PyTorch中,如果想要在不改变原模型的情况下替换forward函数中的torch.gather操作,可以使用torch.index_select函数来实现相同的功能。torch.index_select函数接受一个tensor和一个维度索引作为参数,返回按照指定维度索引的元素。 首先,我们需要理解torch.gather操作的作用。torch.gather可以按照指定的维度,在一个tensor中进行索引,并返回相应的值。例如,对于一个大小为(3, 4)的tensor,我们可以通过torch.gather(tensor, 0, index)来按照第0个维度的索引index来获取对应值。 下面是一个示例代码,展示如何使用torch.index_select替换forward函数中的torch.gather操作: ```python import torch import torch.nn as nn class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.weights = nn.Parameter(torch.randn(3, 4)) def forward(self, index): # 使用torch.gather操作 output = torch.gather(self.weights, 0, index) return output def replace_forward(self, index): # 使用torch.index_select替换torch.gather操作 output = torch.index_select(self.weights, 0, index) return output ``` 在上面的示例代码中,MyModel类的forward函数中使用了torch.gather操作,而replace_forward函数中则使用了torch.index_select来实现相同的功能。这样,我们可以在不改变原模型的情况下替换forward函数中的torch.gather操作。 ### 回答3: 在不改变原模型的情况下,我们可以通过使用其他的操作来替换`torch.gather`。 `torch.gather`操作通常用于根据索引从输入的张量中提取特定元素。它的一般形式是`torch.gather(input, dim, index, out=None)`,其中`input`是输入张量,`dim`是提取索引的维度,`index`是包含提取索引的张量。 为了替换`torch.gather`操作,我们可以使用`torch.index_select`和`torch.unsqueeze`来实现相似的功能。 首先,我们可以使用`torch.index_select`操作来选择指定维度上的索引。这个操作的一般形式是`torch.index_select(input, dim, index, out=None)`,其中`input`是要选择的张量,`dim`是选择的维度,`index`是包含索引的一维张量。 然后,我们可以使用`torch.unsqueeze`操作来在选择的维度上增加一个维度。这个操作的一般形式是`torch.unsqueeze(input, dim, out=None)`,其中`input`是要增加维度的张量,`dim`是要增加的维度。 综上所述,为了替换`torch.gather`操作,我们可以使用以下代码: ```python import torch class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, input, index): # 替换 torch.gather 的操作 output = torch.index_select(input, 1, index.unsqueeze(1)).squeeze(1) return output ``` 在上面的代码中,我们使用`torch.index_select`选择了指定维度`dim=1`上的索引,并使用`torch.unsqueeze`增加了一个维度。最后,我们使用`squeeze`操作将这个额外的维度去除,以匹配`torch.gather`操作的输出。 这样,我们就在不改变原模型的情况下替换了`torch.gather`操作,实现了相似的功能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值