函数torch.gather(input, dim, index)

# torch.gather(input, dim, index)或input.gather(dim, index)
"""
--torch.gather(input, dim, index)函数接受三个参数:
input: 输入张量
dim: 操作维度
index: 索引张量(整数张量,实际上它的每个元素都为非负整数)
作用(输出):从输入张量input中抽一些元素作为输出张量的元素组成,且输出张量的shape == index.shape
关键:怎么从输入张量input中抽元素,具体抽哪一个(对每一步骤)哪些(整体上)元素。

--在讲解关键之前,先介绍几个限制条件:
1.len(input.shape) == len(index.shape)即输入张量的维数(shape的维度个数)必须与索引张量index的维数相同,即它们的shape维度个数相同(大小可以不同)。
2.在torch.gather(input, dim, index)函数中,我们指定了处理维度dim。我们限制对index.shape,除了指定维度(dim)外的其他维度,值<=输入张量input.shape
在该维度的值。实际上,index张量的每个元素都必须严格小于input.shape在指定维度(dim)上的值。
3.输出张量的shape完全由索引张量index决定,output_tensor.shape == index.shape

--为什么要加这三个限制条件?
我们下面会介绍张量元素的位置编号规则,如果不符合上述限制,会找不到要从输入张量中抽取的元素,即超出范围(越界)或者要抽取元素的位置编号在输入张量input中根本不存在。
也可以理解为不符合广播规则。

--为什么在指定处理维度(dim)上,index.shape的值可以比input.shape的值大,也可以比input.shape小(在指定处理dim维度上)?
上面说到输出张量的元素都是从输入张量中的元素抽取的,那么我们可以在输入张量中反复抽取元素,即在输出张量中可以存在两个不同位置的元素具有相同的值,此时output—tensor
与input中的元素可能是多对一关系(对应情形:在指定处理维度(dim)上,index.shape的值比input.shape的值大);同理,也可能存在我们只抽了输入张量中的一部分元素,
没有抽完(存在输入张量中的元素没有被抽到去组成输出张量),这就对应情形:在指定处理维度(dim)上,index.shape的值比input.shape的值小。

--到底怎么从输入张量中抽取元素作为输出张量的元素组成?
我们需要定义一个张量元素的位置编号规则:
例:input:
tensor([[[62, 29, 76, 60],
         [82, 27, 88, 11],
         [57, 50, 71,  9]],

        [[33, 71, 66, 34],
         [20, 81,  3, 39],
         [15, 33, 19, 89]]])
input.shape: (2, 3, 4)  <-> (batch, height, width)
我们对输入张量input中的元素定义位置编号如下(其他同理):
      ([[[000, 001, 002, 003],
         [010, 011, 012, 013],
         [020, 021, 022, 023]],

        [[100, 101, 102, 103],
         [110, 111, 112, 113],
         [120, 121, 122, 123]]])
其实不难理解,张量元素的位置编号的位数就等于张量的维数,这里输入张量input是三维张量,所以对元素的位置编号就是3位数,具体每一位数的含义就等同于(batch, height, width)
拿88举例,它在三维张量input的第一个batch张量(索引:0)的第二行(索引:1)第三列(索引:2),所以它的位置编号就是012.

--讲了这么多,我们还是不知道怎么从输入张量input中抽元素,别急,我们需要处理维度dim和索引张量index:
对index:
tensor([[[0, 1]],
        [[1, 0]]])
index.shape: (2, 1, 2) 考虑input.shape: (2, 3, 4),每一个维度(*dim)都满足index.shape[*dim] <= input.shape[*dim],所以dim可以任取0, 1, 2。总之,
取一个dim,s.t.除此维度的其他维度上都满足限制条件index.shape[*dim] <= input.shape[*dim],这里*dim != dim

--不失一般性,我们取dim=0:
上面提到输出张量的shape就等于索引张量的shape,即output_tensor.shape == index.shape,所以我们就知道输出张量就长这个样子:
output_tensor:
tensor([[[?, ?]],
        [[?, ?]]])
我们同意对索引张量中的元素进行位置编号:      [index.shape: (2, 1, 2) <-> (batch, height, width)]
      ([[[000, 001]],
        [[100, 101]]])

我们的任务就是找出每一个?到底是输入张量input中的哪个元素就好了,那么怎么找->通过我们定义的张量位置编号规则:
索引张量index:
tensor([[[0, 1]],
        [[1, 0]]])

output_tensor:
tensor([[[?, ?]],
        [[?, ?]]])
--索引张量和输出张量都有4个元素(因为shape一样),第一个元素为0,我们给定dim=0,我们考虑输出张量output_tensor的第一个元素位置编号如下:
在给定处理维度dim=0上即位置编号的第一位数是index的第一个元素即0,其他维度上(后两位数)的位置编号继承索引张量第一个元素的位置编号,后两位为00,所以输出张量的第一个元素
(在输入张量中)的位置编号就是000,对应输入张量input中的元素62;

同理,输出张量第二个元素的第一位(我们指定处理维度dim=0)位置编号为index的第二个元素值即1,后两位跟索引张量第二个元素的位置编号相同,故得位置编号101,在输入张量中对应元素值为71;

同理,输出张量的第三个元素在输入张量中的位置编号为(1 00),所以是33;

同理,输出张量的第四个元素在输入张量中的位置编号为(0 01),所以是29;

综上即得:
输出张量output_tensor:
tensor([[[62, 71]],
        [[33, 29]]])

--总结:[torch.gather(input, dim, index)或input.gather(dim, index)]
step1:对输入张量input和索引张量index的元素进行位置编号(给定input, dim, index);

step2:由index.shape创建output_tensor(因为output_tensor.shape == index.shape),之后只需找每一个元素到底是输入张量中的哪一个;

step3:确定输出张量output_tensor当前元素在输入张量中的位置编号:第一位数是index对应位置的元素值(这里因为dim=0所以是第一位数,总之是dim对应的那一位),其他维度上(其他位)与index对应位置元素的位置编号相同。

step4:遍历输出张量output_tensor所有元素,从一个到最后一个,做step3,即逐元素根据step3获得的位置编号得到要从输入张量input中抽取的元素值;
"""

--验证代码:
​​​​​​​import torch
random_seed = 200
torch.manual_seed(random_seed)
input = torch.randint(0, 100, (2, 3, 4))
print("input:")
print(input)

index = torch.randint(0, 2, (2, 1, 2))
print("index:")
print(index)

output = input.gather(0, index)
print("output:")
print(output)

# 控制台输出
input:
tensor([[[62, 29, 76, 60],
         [82, 27, 88, 11],
         [57, 50, 71,  9]],

        [[33, 71, 66, 34],
         [20, 81,  3, 39],
         [15, 33, 19, 89]]])
index:
tensor([[[0, 1]],

        [[1, 0]]])
output:
tensor([[[62, 71]],

        [[33, 29]]])


参考博客:https://blog.csdn.net/Mocode/article/details/131039356

[大家可以先看参考博客,用导购员帮顾客买糖的角度讲的,然后再看这个加深理解]

  • 8
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值