torch.gather()详解

torch.gather()

定义:从原tensor中获取指定dim和指定index的数据

看到这个核心定义,我们很容易想到gather()基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下

用途:方便从批量tensor中获取指定索引下的数据,该索引是 高度自定义化的,可乱序的

官方文档

在这里插入图片描述

根据官方文档的显示是根据给出的index的索引坐标来确定要寻找的坐标,然后根据dim来确定将哪个位置坐标换位index中的数字,剩余维度的位置坐标保持不变。 并且input的shape应该是和index的shape保持一致。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

到这里再回去看官方文档是不是就能看懂了!!!

【PyTorch】Torch.gather()用法详细图文解释

实战

import torch
import torch.nn.functional as F

a = torch.arange(0,16).view(4,4)
print(a)
index = torch.tensor([[0,1,2,3]]) 
print(index)
print(a.gather(0, index))
print(a.gather(1, index))
index = torch.tensor([[3, 2, 1, 0]])
print(index)
tensor_1 = a.gather(0, index)
print(tensor_1)

tesnor_2 = a.gather(1,index)
print(tesnor_2)

输出:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[0, 1, 2, 3]])
tensor([[ 0,  5, 10, 15]])
tensor([[0, 1, 2, 3]])
tensor([[3, 2, 1, 0]])
tensor([[12,  9,  6,  3]])
tensor([[3, 2, 1, 0]])

三维实例:

import torch
random_seed = 200
torch.manual_seed(random_seed)
input = torch.randint(0, 100, (2, 3, 4))
print("input:")
print(input)

index = torch.randint(0, 2, (2, 1, 2))
print("index:")
print(index)

output = input.gather(0, index)
print("output:")
print(output)

# 控制台输出
input:
tensor([[[62, 29, 76, 60],
         [82, 27, 88, 11],
         [57, 50, 71,  9]],

        [[33, 71, 66, 34],
         [20, 81,  3, 39],
         [15, 33, 19, 89]]])
index:
tensor([[[0, 1]],

        [[1, 0]]])
output:
tensor([[[62, 71]],

        [[33, 29]]])

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值