pytorch的gather函数的一些理解
先给出官方文档的解释,我觉得官方的文档写的已经很清楚了,四个参数分别是input,dim,index,out,输出的tensor是以index为大小的tensor。
其中,这就是最关键的定义
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3
主要解释一下dim,dim=0的时候,把index的元素放入行进行索引,有一点需要注意的是,参数index的tensor格式是除了第1维也就是行那一维之外,其他维的格式需与input保持一致!下面给个例子
import torch
a = torch.arange(0, 16).view(4,4)
index = torch.LongTensor([[0,1,2,3]])
b = a.gather(0, index)
print(a)
print(index)
print(b)
#形象的理解就是在每一列的第index[]上进行索引
for j in range(4):
print(a[index[0][j]][j].item())
--------------------------------------------------------------------
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[0, 1, 2, 3]])
tensor([[ 0, 5, 10, 15]])
0
5
10
15
dim = 1的时候,把index的元素放入列进行索引,有一点需要注意的是,参数index的tensor格式是除了第2维也就是列那一维之外,其他维的格式需与input保持一致!下面给个例子
import torch
a = torch.arange(0, 16).view(4,4)
index = torch.LongTensor([[0],[1],[2],[3]])
b = a.gather(1, index)
print(a)
print(index)
print(b)
#形象的理解就是在每一行的第index[]列上进行索引
for j in range(4):
print(a[j][index[j][0]].item())
--------------------------------------------------------------------
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[0],
[1],
[2],
[3]])
tensor([[ 0],
[ 5],
[10],
[15]])
0
5
10
15
本人对矩阵的一些概念还有一些模糊不清,以上就是我的一些理解,希望有大佬可以一起交流一下,pytorch 的张量一开始很难处理清楚,还需慢慢来。