pytorch gather函数理解
注释中的说法如下
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=2
看起来有点不好理解,因此做了如下实验得出结论
1. 对index的shape的要求:指定dim后,index中的除指定的dim以外的其他维度大小必须和input中对应的维度大小相同。不满足这个就会报错。
2. 结果是如何计算出来的?
注释中的三维数组的不好理解,看下二维的
例如:input的结构是[i,j],dim=0, index.shape是[INDi,INDj]其中INDj必须等于j,计算的结果是input[i]index[INDi][INDj],这里通过index[INDi,INDj]来确定在input中取值时j的索引。
取值的顺序也是按照index的顺序进行的,先行后列。代码执行如下
In[2]: import torch as tc
In[3]: input =tc.arange(1,17).view(4,4)
In[4]: input
Out[4]:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
In[5]: index = tc.LongTensor([[0,1,0,0]])
In[6]: output = input.gather(0,index)
In[7]: output
Out[7]: tensor([[1, 6, 3, 4]])
因为dim=0,在二维数组中行的下标取值就来源于index,
index[0][0]=0,那么取input[0][0]=1
index[0][1]=1,那么取input[1][1]=6
index[0][2]=0,那么取input[0][2]=3
index[0][3]=0,那么取input[0][3]=4
index = tc.LongTensor([[0,1,0,0],[1,1,1,1]])
output = input.gather(0,index)
以下其他代码可以自己在console中执行以下看效果。
index = tc.LongTensor([[0,1,0,0],[1,1,1,1]])
output = input.gather(0,index)
index = tc.LongTensor([[1],[2],[3],[0]])
output = input.gather(1,index)
index = tc.LongTensor([[1,0],[2,3],[3,1],[0,1]])
output = input.gather(1,index)
3. index的shape和output的shape相同