Pytorch之gather的用法

首先了解下函数中的参数:
torch.gather(input, dim, index, out=None) → Tensor
Parameters:

input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor


input :需要索引的 tensor
dim : 指索引的维度 (0代表横向 1代表纵向 以此类推)
index: 索引的下标

接下来直接上例子解释

import torch

b = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1))
print (torch.gather(b, dim=0, index=index_2))

输出:

tensor([[1., 2.],
        [6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 3.]])


第一个式子 dim=1:torch.gather(b, dim=1, index=index_1)
input : b =

1,2,3
4,5,6


dim = 1 :代表的是维度1也就是列
index =

0,1
2,0


了解了输入后我们分步进行解析

index 的指就是代表对应维度,这里dim=1 ,0就代表第0列,1就代表第一列,2就代表第二列,我们先把每一个输出的值在input中的坐标的列写出来,注意一点,输出的shape也就是index的shape

(,0),(,1)
(,2),(,0)

这样我们就完成了每个输出所在input中的坐标的列的定位

接下来每个输出的定位横坐标。每个输出的横坐标,也就是所在输出的横坐标

(0,0),(0,1)
(1,2),(1,0)

最后用我们上面得到的坐标去获取input中对应的值

1,2
6,4

第二个式子 dim=0 :torch.gather(b, dim=0, index=index_2)
input : b =

1,2,3
4,5,6


dim = 0 :代表的是维度0也就是行
index =

0,1,1
0,0,0

有了上个式子的经验,第一步当然是写出对应dim的坐标啦,这个式子dim=0,也就可以先写出横坐标,这里的横坐标,就是index对应的值

(0,),(1,),(1,)
(0,),(0,),(0,)

接下来写出纵坐标,纵坐标也就是输出所对应的纵坐标

(0,0),(1,1),(1,2)
(0,0),(0,1),(0,2)

最后写出对应input的值

1,5,6
1,2,3

有了上面两个式子的解释,现在可以总结出gather 的用法了
gather的用法就是index所提供要索引的dim维的位置,其余维度的位置也就是index对应的位置 ,也就是输出的坐标,把dim维的替换成index中对应的数字 。

还不理解的话,再举个官方的例子:

>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

index 的中每个元素的坐标为:

(0,0),(0,1)
(1,0),(1,1)

dim= 1 ,也就是把第二个维度的坐标替换成index中的值

(0,0),(0,0)
(1,1),(1,0)

最后写出对应input中的值

1,1
4,3

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值