torch.gather() 用法解读

torch.gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor

沿\(dim\)指定的轴和\(index\)指定的索引从\(input\)中提取对应的值。

对于一个三维张量

\(output[i][j][k]=input[index[i][j][k]][j][k]  \quad \#\enspace if \enspace dim==0\) 

\(output[i][j][k]=input[i][index[i][j][k]][k]  \quad \#\enspace if \enspace dim==1\) 

\(output[i][j][k]=input[i][j][index[i][j][k]]  \quad \#\enspace if \enspace dim==2\) 

\(input\)和\(index\)的\(dimensions\)数目必须相同。 \(out\)和\(index\)的\(shape\)是相同的。(注意\(dimensions\)和\(shape\)的区别)

示例

下面用两个例子来解释一下具体的用法

例1

import torch

dim = 0
_input = torch.tensor([[10, 11, 12],
                       [13, 14, 15],
                       [16, 17, 18]])
index = torch.tensor([[0, 1, 2],
                      [1, 2, 0]])

output = torch.gather(_input, dim, index)

print(output)
# tensor([[10, 14, 18],
#         [13, 17, 12]])

该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(2, 3)。

因为dim=0,index中的每个数其值代表dim=0即"行"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1, 2]分别表示第0、1、2行,而这三个数本身在dim=1维度的索引为0、1、2即第0、1、2列。因此第一个数0定位到_input中的第0行,而0本身在index中的第0列,因此又定位到_input的第0列,这样就找到了10这个数,同理找到14和18。

index中的第1行[1, 2, 0]分别表示_input中的第1、2、0行和第0、1、2列,因此找到_input中对应的数[13, 17, 12]。

例2

import torch

dim = 1
_input = torch.tensor([[10, 11, 12],
                       [13, 14, 15],
                       [16, 17, 18]])
index = torch.tensor([[0, 1],
                      [1, 2], 
                      [2, 0]])

output = torch.gather(_input, dim, index)
print(output)
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])

该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(3, 2)。

因为dim=1,index中的每个数其值代表dim=1即"列"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1]分别表示第0、1列,而这三个数本身在dim=0维度的索引为0即第0行。因此第一个数0定位到_input中的第0列,而0本身在index中的第0行,因此又定位到_input的第0行,这样就找到了10这个数,同理找到11。

index中的第1行[1, 2]分别表示_input中的第1、2列和第1行,因此找到_input中对应的数[14, 15]。

index中的第2行[2, 0]分别表示_input中的第2、0列和第2行,因此找到_input中对应的数[18, 16]。

总结

上面的示例是二维的情况,同理也可以推广到三维甚至更多维。总结来说,index中每个数其本身的值表示参数dim指定维度的索引,而其它的每个维度都由每个数在index中的对应维度的索引指定。

参考

torch.gather — PyTorch 1.12 documentation

python - What does the gather function do in pytorch in layman terms? - Stack Overflow

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值