torch.gather——沿特定维度收集数值

PyTorch学习笔记:torch.gather——沿特定维度收集数值

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

功能:从输入的数组中,沿指定的dim维度,利用索引变量index,将数据索引出来,并且堆叠成一个数组。直观可能不好理解,具体可以见代码案例。

输入:

input:输入的数组

dim:指定的维度

index:索引变量,数据类型需是长整型(int64)

注意:

  • inputindex具有相同的维数

  • outindex具有相同的形状

  • 除了dim维度,在每个维度上,索引在该维度上的大小要小于等于输入在该维度上的大小,即:
    i n d e x . s i z e ( d ) ≤ i n p u t . s i z e ( d ) , d ! = d i m index.size(d)≤input.size(d),\quad d!=dim index.size(d)input.size(d),d!=dim

代码案例

一般用法,当在一个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2,3,1,3]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

在这里插入图片描述

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2],[3],[1],[3]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

在这里插入图片描述

当同时在两个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2,3],[2,3,0],[3,0,1]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
tensor([[ 5, 11, 17],
        [10, 16,  2],
        [15,  1,  7]])

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2],[2,3],[3,4]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
tensor([[ 1,  2],
        [ 7,  8],
        [13, 14]])

官方文档

torch.gather:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=torch%20gather#torch.gather

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值