pytorch gather函数理解

pytorch gather函数理解

注释中的说法如下
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0

out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1

out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=2

看起来有点不好理解,因此做了如下实验得出结论

1. 对index的shape的要求:指定dim后,index中的除指定的dim以外的其他维度大小必须和input中对应的维度大小相同。不满足这个就会报错。
2. 结果是如何计算出来的?
注释中的三维数组的不好理解,看下二维的
例如:input的结构是[i,j],dim=0, index.shape是[INDi,INDj]其中INDj必须等于j,计算的结果是input[i]index[INDi][INDj],这里通过index[INDi,INDj]来确定在input中取值时j的索引。
取值的顺序也是按照index的顺序进行的,先行后列。代码执行如下

In[2]: import torch as tc
In[3]: input =tc.arange(1,17).view(4,4) 
In[4]: input
Out[4]: 
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])
In[5]: index = tc.LongTensor([[0,1,0,0]])
In[6]: output = input.gather(0,index)
In[7]: output
Out[7]: tensor([[1, 6, 3, 4]])

因为dim=0,在二维数组中行的下标取值就来源于index,
index[0][0]=0,那么取input[0][0]=1
index[0][1]=1,那么取input[1][1]=6
index[0][2]=0,那么取input[0][2]=3
index[0][3]=0,那么取input[0][3]=4
index = tc.LongTensor([[0,1,0,0],[1,1,1,1]])
output = input.gather(0,index)

以下其他代码可以自己在console中执行以下看效果。

index = tc.LongTensor([[0,1,0,0],[1,1,1,1]])
output = input.gather(0,index)

index = tc.LongTensor([[1],[2],[3],[0]])
output = input.gather(1,index)

index = tc.LongTensor([[1,0],[2,3],[3,1],[0,1]])
output = input.gather(1,index)


3. index的shape和output的shape相同

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值