在Pytorch中,index_select
和gather
均是被用于张量选取的常用函数,本文通过实例来对比这两个函数。
1. index_select
沿着张量的某个dim
方向,按照index
规定的选取指定的低一维度张量元素整体,在拼接成一个张量。其官方解释如下:
torch.index_select(input, dim, index, out=None)
"""
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor
"""
先简单看两个示例:
示例1:沿着dim=0
的方向进行
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.index_select(a, dim=0, index=torch.tensor([0,1,0,1]))
# b为tensor([[1, 2, 3],
# [4, 5, 6],
# [1, 2, 3],
# [4, 5, 6]])
显然,对于二维张量,dim=0
意味着按照index
的编号选取指定的行,拼接成目标张量。其返回值仍保持和原始张量相同的ndim
。
示例2:沿着dim=1
的方向进行
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.index_select(a, dim=1, index=torch.tensor([1,1]))
# b为tensor([[2, 2],
# [5, 5]])
对于二维张量,dim=1
意味着按照index
的编号选取指定的列,拼接成目标张量。其返回值仍保持和原始张量相同的ndim
。
根据上述两个例子,可见index_select
的作用间接明了,即选取某个dim
上的若干个元素,将其拼接为目标张量。其中index
为一个一维张量,表明该dim
上做选取的具体元素,返回张量与原张量的ndim
一致。
2. gather
相较于index_select
,gather
就显得让人难以理解的多。个人理解,其操作相当于用于沿着张量的某个dim
方向,按照index
规定的选取指定元素,构成该为维度上的每个子张量,最后拼接成一个张量。其官方解释如下:
torch.gather(input, dim, index, out=None, sparse_grad=False)
"""
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
"""
是不是还是令人费解?我们先以两个2维张量的例子来说明:
示例1:沿着dim=1
的方向进行选择
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.gather(input=a, dim=1, index=torch.tensor([[2,0,2,1], [1,1,0,0]]))
# 返回值为 tensor([[3, 1, 3, 2],
# [5, 5, 4, 4]])
其操作过程可参照下图:
由上图可见,dim=1
表示在二维张量中,以行为单位,对每行中的元素,按照index
的索引号进行选取,再拼接到一起。从张量shape
上看,其在dim=0
上保持一致,对dim=1
进行了放大或缩小。
对于更一般的张量,gather
的过程可理解为沿着dim
维的size
,对各个子张量进行选取和重新的拼接,因此其返回值和原始张量的ndim
是相同的。
示例2:沿着dim=0
的方向进行选择
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a为tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.gather(input=a, dim=0, index=torch.tensor([[0, 1, 0], [1,0,1], [0, 0, 0],[1,1,1]]))
# 返回值为 tensor([[1, 5, 3],
# [4, 2, 6],
# [1, 2, 3],
# [4, 5, 6]])
其操作过程可参照下图:
对于二维张量,其操作过程与dim=1
相反,即以行为单位,对每列中的元素,按照index
的索引号进行选取,再拼接到一起。从张量shape
上看,其在dim=1
上保持一致,对dim=0
进行了放大或缩小。
示例3:三维张量的例子
a = torch.arange(24).reshape(2,3,4)
# a为tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
b = torch.gather(a, dim=2, index=torch.tensor([[[2], [1], [0]], [[1], [2], [3]]]))
# b为tensor([[[ 2],
# [5],
# [8]],
# [[13],
# [18],
# [23]]])
简单解释下,其选取的dim=2
,即沿着三维张量最内层的张量进行元素选取和拼接,其只选取了一次。因此,上述操作可理解为每个最内层选取一个元素。
3. 总结
index_select
和gather
虽然都可用于张量元素的选取和重塑,主要参数的命名也类似,但其功能截然不同。简要而言:
(1)index_select
用于对dim
方向各子张量的整体选取和拼接,其中的index
为一维张量;
(2)gather
用于对dim
方向各子张量的元素在其它维度方向上的选取和拼接,其中的index
为与原张量同ndim
的张量。