python之gather()函数详解

首先给出pytorch官方定义:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

我们只需要关注input、dim和index三个参数即可(input即被index索引的原始tensor,dim即index中的元素在input的下标中占那个位置,例如有索引a[i][j],当dim=0时,index中的元素占第一个位置,即i的位置。index当然就是input的索引啰。)

然后贴我自己摸索的代码,能看懂的请直接划走!

import torch

a = torch.arange(15).view(3, 5)
print("a:",a,a.shape,a[2][0])

"""
a: tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]]) torch.Size([3, 5])
"""

b = torch.zeros_like(a)
# print("b:",b)

b[1][2] = 1
b[0][0] = 2
print("b:",b)

"""
b: tensor([[2, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0]])
"""

c = a.gather(0, b) # dim=0
print("c:",c)

"""
c: tensor([[10,  1,  2,  3,  4],
        [ 0,  1,  7,  3,  4],
        [ 0,  1,  2,  3,  4]])
dim = 0时,
c = {
    [a[2][0],a[0][1],a[0][2],a[0][3],a[0][4]],
    [a[0][0],a[0][1],a[1][2],a[0][3],a[0][4]],
    [a[0][0],a[0][1],a[0][2],a[0][3],a[0][4]]
    }
"""

d = a.gather(1, b) # dim=0
print("d:",d)

"""
d: tensor([[ 2,  0,  0,  0,  0],
        [ 5,  5,  6,  5,  5],
        [10, 10, 10, 10, 10]])
"""

别装了,你看不懂的,快看下面的内容:

c = a.gather(0, b) # dim=0
print("c:",c)

"""
c: tensor([[10,  1,  2,  3,  4],
        [ 0,  1,  7,  3,  4],
        [ 0,  1,  2,  3,  4]])
dim = 0时,
c = {
    [a[2][0],a[0][1],a[0][2],a[0][3],a[0][4]],
    [a[0][0],a[0][1],a[1][2],a[0][3],a[0][4]],
    [a[0][0],a[0][1],a[0][2],a[0][3],a[0][4]]
    }
"""

按照我的理解(不一定对哈),dim=0时,张量b中的元素填入a的对应位置的第0维索引,如下图:
在图片中上图中橙色箭头指向的即是索引与输入的对应关系,剩余的请自己摸索。

  • 7
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值