pytorch中的gather函数
可以看成是将 input中的值进行挑选后赋给output,挑选规则为:output中的第i行第j列的的值是input中第Index(i,j)行,第j列的值,此时dim=0,相当于列不变;或者是第i行,第Index(i,j)列的值,此时dim=1。Index的维度与output相同。
换一种思路,我们可以理解为根据我们想要的output,观察input,设计index
Input= 1 2 3
4 5 6
我们希望 output= 1 2
6 4
按行gather:
DIM=0 希望output(0,0)=1, 1在input的(0,0)位置,故希望index(0,0)=0
希望output(0,1)=2, 2在input的(0,1)位置,故希望index(0,1)=0
.。。。。。。
DIM=1 希望output(0,0)=1, 1在input的(0,0)位置,故希望index(0,0)=0
希望output(0,1)=2, 2在input的(0,1)位置,故希望index(0,1)=1
希望output(1,0)=6, 6在input的(1,2)位置,故希望index(1,0)=2
希望output(1,1)=4, 4在input的(1,0)位置,故希望index(1,1)=0
.等等
Index = 0 1
2 0
总结,实际用途dim=1,我们想要output[i][j]对应input中i行m列的位置(同一行中找),我们就让Index(i,j)=m
Out[i][j]=input[i,index[i,j]],其中index[i,j]=m
dim=0,我们想要output[i][j]对应input中m行j列的位置(同一列中找),我们就让Index(i,j)=m
Out[i][j]=input[index[i,j],j],其中index[i,j]=m
b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3