对torch.gather的简单理解

文章介绍了如何使用PyTorch中的torch.gather函数来根据指定的索引从输入矩阵中提取值。它展示了二维矩阵的例子,解释了输入矩阵、索引矩阵和输出大小的关系,并给出了一个具体的操作示例,强调了索引在不同维度上的应用。
摘要由CSDN通过智能技术生成

看了一堆花里胡哨的,还是通过官方的定义看的明白

 输入一个矩阵 

x = torch.arange(1,7).reshape(2,3)
Out[4]: 
tensor([[1, 2, 3],
        [4, 5, 6]])

 x是二维的,两行三列

输入一个索引矩阵index(LongTensor不是Tensor),维度(dim)要和x一致

index = torch.LongTensor([[0,1,1]])
index
Out[6]: tensor([[0, 1, 1]])
index.shape
Out[7]: torch.Size([1, 3])

torch.gather的输出size和index是一致的,这里是1×3一行三列的一个矩阵

我们先写一个一行三列的矩阵三个值分别为[out00,out01,out02]

torch.gather(input,dim,index) input为输入矩阵,dim为维数,index为索引矩阵

这里根据维数和索引值去更换我需要的值,例如我dim=0,我就把这里的00,01,02的第一个0换成index输入的值

索引什么值呢,就是input的矩阵中对应的值,对应到x就是1,5,6,即gather的输出 

 同理可扩展至更高维度

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值