【Pytorch】index_select和gather函数的对比

在Pytorch中,index_selectgather均是被用于张量选取的常用函数,本文通过实例来对比这两个函数。

1. index_select

沿着张量的某个dim方向,按照index规定的选取指定的低一维度张量元素整体,在拼接成一个张量。其官方解释如下:

torch.index_select(input, dim, index, out=None) 
"""
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor
"""

先简单看两个示例:
示例1:沿着dim=0的方向进行

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.index_select(a, dim=0, index=torch.tensor([0,1,0,1]))
# b为tensor([[1, 2, 3],
#        [4, 5, 6],
#        [1, 2, 3],
#        [4, 5, 6]])

在这里插入图片描述
显然,对于二维张量,dim=0意味着按照index的编号选取指定的行,拼接成目标张量。其返回值仍保持和原始张量相同的ndim

示例2:沿着dim=1的方向进行

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.index_select(a, dim=1, index=torch.tensor([1,1]))
# b为tensor([[2, 2],
#          [5, 5]])

在这里插入图片描述
对于二维张量,dim=1意味着按照index的编号选取指定的列,拼接成目标张量。其返回值仍保持和原始张量相同的ndim

根据上述两个例子,可见index_select的作用间接明了,即选取某个dim上的若干个元素,将其拼接为目标张量。其中index为一个一维张量,表明该dim上做选取的具体元素,返回张量与原张量的ndim一致。

2. gather

相较于index_selectgather就显得让人难以理解的多。个人理解,其操作相当于用于沿着张量的某个dim方向,按照index规定的选取指定元素,构成该为维度上的每个子张量,最后拼接成一个张量。其官方解释如下:

torch.gather(input, dim, index, out=None, sparse_grad=False)

"""
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
"""

是不是还是令人费解?我们先以两个2维张量的例子来说明:

示例1:沿着dim=1的方向进行选择

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.gather(input=a, dim=1, index=torch.tensor([[2,0,2,1], [1,1,0,0]]))
# 返回值为 tensor([[3, 1, 3, 2],
#         [5, 5, 4, 4]])

其操作过程可参照下图:
在这里插入图片描述
由上图可见,dim=1表示在二维张量中,以行为单位,对每行中的元素,按照index的索引号进行选取,再拼接到一起。从张量shape上看,其在dim=0上保持一致,对dim=1进行了放大或缩小。

对于更一般的张量,gather的过程可理解为沿着dim维的size,对各个子张量进行选取和重新的拼接,因此其返回值和原始张量的ndim是相同的。

示例2:沿着dim=0的方向进行选择

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.gather(input=a, dim=0, index=torch.tensor([[0, 1, 0], [1,0,1], [0, 0, 0],[1,1,1]]))
# 返回值为 tensor([[1, 5, 3],
#        [4, 2, 6],
#        [1, 2, 3],
#        [4, 5, 6]])

其操作过程可参照下图:
在这里插入图片描述
对于二维张量,其操作过程与dim=1相反,即以行为单位,对每列中的元素,按照index的索引号进行选取,再拼接到一起。从张量shape上看,其在dim=1上保持一致,对dim=0进行了放大或缩小。

示例3:三维张量的例子

a = torch.arange(24).reshape(2,3,4)
# a为tensor([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]],

#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

b = torch.gather(a, dim=2, index=torch.tensor([[[2], [1], [0]], [[1], [2], [3]]]))

# b为tensor([[[ 2],
#            [5],
#            [8]],

#        [[13],
#        [18],
#        [23]]])

简单解释下,其选取的dim=2,即沿着三维张量最内层的张量进行元素选取和拼接,其只选取了一次。因此,上述操作可理解为每个最内层选取一个元素。

3. 总结

index_selectgather虽然都可用于张量元素的选取和重塑,主要参数的命名也类似,但其功能截然不同。简要而言:
(1)index_select 用于对dim方向各子张量的整体选取和拼接,其中的index为一维张量;
(2)gather用于对dim方向各子张量的元素在其它维度方向上的选取和拼接,其中的index为与原张量同ndim的张量。

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值