pytorch中的gather函数

pytorch中的gather函数
可以看成是将 input中的值进行挑选后赋给output,挑选规则为:output中的第i行第j列的的值是input中第Index(i,j)行,第j列的值,此时dim=0,相当于列不变;或者是第i行,第Index(i,j)列的值,此时dim=1。Index的维度与output相同。

换一种思路,我们可以理解为根据我们想要的output,观察input,设计index
Input= 1 2 3
4 5 6
我们希望 output= 1 2
6 4
按行gather:
DIM=0 希望output(0,0)=1, 1在input的(0,0)位置,故希望index(0,0)=0
希望output(0,1)=2, 2在input的(0,1)位置,故希望index(0,1)=0
.。。。。。。
DIM=1 希望output(0,0)=1, 1在input的(0,0)位置,故希望index(0,0)=0
希望output(0,1)=2, 2在input的(0,1)位置,故希望index(0,1)=1
希望output(1,0)=6, 6在input的(1,2)位置,故希望index(1,0)=2
希望output(1,1)=4, 4在input的(1,0)位置,故希望index(1,1)=0
.等等
Index = 0 1
2 0
总结,实际用途dim=1,我们想要output[i][j]对应input中i行m列的位置(同一行中找),我们就让Index(i,j)=m
Out[i][j]=input[i,index[i,j]],其中index[i,j]=m

dim=0,我们想要output[i][j]对应input中m行j列的位置(同一列中找),我们就让Index(i,j)=m
Out[i][j]=input[index[i,j],j],其中index[i,j]=m

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)

 1  2  3
 4  5  6
[torch.FloatTensor of size 2x3]


 1  2
 6  4
[torch.FloatTensor of size 2x2]


 1  5  6
 1  2  3


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值