pytorch二维三维数据下的gather函数

gather函数真的是让人头大,看了一会,趁着还有些印象赶紧记下来。
        首先说明一下,gather函数的inputindex必须有相同的维度,例如你输入的数据是二维的,那么index也必须是二维的,但他们的shape可以不同,最终输出的结果与输入index相同。接下来首先通过二维张量举例。

import torch

input = torch.arange(16).view(4, 4)
"""
[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]
"""

        接下来创建一个index,当然这个index也要是二维的。

index = torch.LongTensor([[0, 2, 2, 3]])

        首先讨论当dim=0的情况

output = input.gather(dim=0, index)
"""

[[0, 9, 10, 15]]
"""

        当dim=0的时候,可以看作按照row来选择,其中index里的数值则表示选择了哪一行,举个栗子,原index的数值为[0.2.2.3],那么表示这四个数分别来自第0行、第2行、第2行、第3行,好了,那接下来行确定了,具体选择哪个数字呢,那就要根据原index中数字所在的位置确定了,比如说[0.2.2.3]中0的索引就是0,第一个2的索引为1,第二个2索引的位置为2,3的索引为3,将这两部分拼凑在一起吧,于是我们得到了最终想要的结果。

        紧接着dim=1的情况也好说明了吧

index = torch.LongTensor([[0, 2, 2, 3]])

output = input.gather(dim=1, index)
"""

[[0 6 10 15]]
"""

        同理,只要我们将index里面的数据按照列来选择就好了,列确定之后,具体的数值依旧是按照原index里面的索引位置确定。

        当二维的我们知道怎么回事了,现在该讨论三维数据下的情况了。老规矩,先建个三维张量吧。当然我们的index也要变成三维的了。

c=torch.arange(24).view(3,2,4)

'''
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23]]])
'''

index=tensor([[[0, 0, 0, 0],
               [1, 1, 1, 1]],

               [[1, 1, 1, 1],
               [2, 2, 2, 2]],

               [[2, 2, 2, 2],
               [2, 2, 2, 2]]])

        当数据变成三维的时候,这个时候多出了一个维度的数据,这时候该怎么理解呢。

dim=0, 从一个batch中选择数据。
dim=1, 从一个batch中选择某些行。
dim=2,从一个batch中选择某些列。

        举个栗子。当dim=0时,里面的数字代表了batch,所以说我们先看最上面的一行张量[0,0,0,0],这句话代表,这行的四个数字全部来自input的第0个batch块,也就是input中0~7的这部分,那具体的选择哪个数字呢,还是老规矩,确定每个数字的索引值就好啦,[0,0,0,0]在index其所在batch块中的索引值为[0,0] 、[0,1]、[0,2]、[0,3],我们再将我们所获取的信息拼接起来,就得到了我们想要的结果了。听起来有点绕,但是我们只要根据dim所给的值确定对应的维度,再根据位置获取对应元素,理解起来就容易多了。

 

c.gather(dim=0,index=index)
'''
c = tensor([[[ 0,  1,  2,  3],
             [ 4,  5,  6,  7]],

            [[ 8,  9, 10, 11],
             [12, 13, 14, 15]],

            [[16, 17, 18, 19],
             [20, 21, 22, 23]]])

index=tensor([[[0, 0, 0, 0],
             [1, 1, 1, 1]],

            [[1, 1, 1, 1],
             [2, 2, 2, 2]],

            [[2, 2, 2, 2],
             [2, 2, 2, 2]]])

output=tensor([[[ 0,  1,  2,  3],
             [12, 13, 14, 15]],

            [[ 8,  9, 10, 11],
             [20, 21, 22, 23]],

            [[16, 17, 18, 19],
             [20, 21, 22, 23]]])

'''

再来个dim=1的例子:

 input=tensor([[[ 0,  1,  2,  3],
             [ 4,  5,  6,  7]],

            [[ 8,  9, 10, 11],
             [12, 13, 14, 15]],

            [[16, 17, 18, 19],
             [20, 21, 22, 23]]])
index=tensor([[[1, 0, 0, 1],
             [0, 0, 0, 0]],

            [[1, 1, 1, 1],
             [1, 1, 1, 1]],

            [[0, 0, 0, 0],
             [1, 1, 1, 1]]])
output=tensor([[[ 4,  1,  2,  7],
             [ 0,  1,  2,  3]],

            [[12, 13, 14, 15],
             [12, 13, 14, 15]],

            [[16, 17, 18, 19],
             [20, 21, 22, 23]]])

        有了上面的例子,这次就应该更好理解了一些,既然这次时dim=1,那么就说明batch块是确定的了,这不就由回到二维的情况了嘛。

        也不差dim=2的例子了。

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23]]])
tensor([[[1, 0, 0, 1],
         [0, 0, 0, 0]],

        [[0, 1, 2, 3],
         [1, 1, 1, 1]],

        [[0, 0, 0, 0],
         [1, 1, 1, 1]]])
tensor([[[ 1,  0,  0,  1],
         [ 4,  4,  4,  4]],

        [[ 8,  9, 10, 11],
         [13, 13, 13, 13]],

        [[16, 16, 16, 16],
         [21, 21, 21, 21]]])

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值