函数定义
torch.gather(input, dim, index, out=None)
其中,
input(tensor): 待操作数,源张量
dim(int): 维度。
index(LongTensor): 索引下标。
out: 输出的张量
注意,输出的张量和index的size是一致的。
举例
dim=0
输入:
a = torch.arange(0,16).view(4,4)
print(a)
index = torch.LongTensor([[0,1,2,3]]) # index的元素必须是int型数据,如果使用torch.Tensor会报错
print(a.gather(0, index))
输出:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 5, 10, 15]])
解释:dim=0代表针对0维,即输出为行形式。不难看出,此时index的[0,1,2,3]分别指向源张量a的第0,1,2,3行的第0,1,2,3列的数据。被指向的数据分别为0,5,10,15,又因为dim=0代表“行”,故输出tensor([[ 0, 5, 10, 15]])
dim=1
输入:
index3 = torch.LongTensor([[0,1,2,3]]).t() # 添加上t.()将index3变成一列的数据
print(a.gather(1,index3))
输出:
tensor([[ 0],
[ 5],
[10],
[15]])
解释:dim=0代表针对0维,即输出为行形式。不难看出,此时index的[0,1,2,3]分别指向源张量a的第0,1,2,3列的第0,1,2,3行的数据。被指向的数据分别为0,5,10,15,又因为dim=1代表“列”,故有上述输出。
dim=2
输入:
a = torch.randint(0, 30, (2, 3, 5))
print("------a------")
print(a)
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print("----index----")
print(index)
c = torch.gather(a, 2,index) # 除了使用a.gather(1,index)外,也可以torch.gather(a,1,index)
print("------c------")
print(c)
输出:
------a------
tensor([[[ 2, 0, 28, 28, 12],
[28, 28, 13, 26, 13],
[16, 7, 26, 21, 3]],
[[19, 16, 29, 20, 23],
[10, 26, 4, 24, 26],
[26, 14, 28, 3, 25]]])
----index----
tensor([[[0, 1, 2, 0, 2],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2],
[0, 0, 0, 0, 0],
[2, 2, 2, 2, 2]]])
------c------
tensor([[[ 2, 0, 28, 2, 28],
[28, 28, 28, 28, 28],
[ 7, 7, 7, 7, 7]],
[[16, 29, 29, 29, 29],
[10, 10, 10, 10, 10],
[28, 28, 28, 28, 28]]])
解释:
我们可以看index[0][0],为[0,1,2,0,2],其中的“0 1 2 0 2”五个数字分别指向源张量第0页第0行的第“0 1 2 0 2”列数据,故输出张量c[0][0]=[2,0,28,2,28]。类推即可理解dim=2的输出情况。