torch.index_select()——数组索引

torch.index_select()——数组索引

torch.index_select(input, dim, index, *, out=None) → Tensor

功能:选择根据给定的indexdiminput中选择张量数据,相当于更高级的索引功能。

输入:

  • input:需要索引的张量数组
  • dim:索引维度(沿dim维度进行索引)
  • index:索引值,可以是单个数字、也可以是一个序列(一维序列)

注意:

  • 返回的张量数组与原始的张量数组具有相同的维数,这里与直接进行索引有区别,具体见案例代码;
  • dim维度的尺寸大小与index的长度相同,其他尺寸大小与原始张量中的尺寸相同
  • index:维数必须小于等于1

案例代码

当index索引为单个数字时

import torch
a=torch.arange(40).view(2,4,5)
index=torch.tensor([1])
select_1=torch.index_select(a,dim=0,index=index)
select_2=torch.index_select(a,dim=1,index=index)
select_3=torch.index_select(a,dim=2,index=index)
print(select_1)
print(select_2)
print(select_3)
print(a.shape)
print(select_1.shape)
print(select_2.shape)
print(select_3.shape)

输出

# dim=0
tensor([[[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])
# dim=1
tensor([[[ 5,  6,  7,  8,  9]],

        [[25, 26, 27, 28, 29]]])
# dim=2
tensor([[[ 1],
         [ 6],
         [11],
         [16]],

        [[21],
         [26],
         [31],
         [36]]])
# 索引后的数组尺寸,除了dim部分,其他和原来大小一样
# 原始
torch.Size([2, 4, 5])
# dim=0
torch.Size([1, 4, 5])
# dim=1
torch.Size([2, 1, 5])
# dim=2
torch.Size([2, 4, 1])

在这里插入图片描述

当索引为列表时

import torch
a=torch.arange(40).view(2,4,5)
index=torch.tensor([1,3])
select_1=torch.index_select(a,dim=1,index=index)
select_2=torch.index_select(a,dim=2,index=index)
print(select_1)
print(select_2)

输出

# dim=1
tensor([[[ 5,  6,  7,  8,  9],
         [15, 16, 17, 18, 19]],

        [[25, 26, 27, 28, 29],
         [35, 36, 37, 38, 39]]])
# dim=2
tensor([[[ 1,  3],
         [ 6,  8],
         [11, 13],
         [16, 18]],

        [[21, 23],
         [26, 28],
         [31, 33],
         [36, 38]]])

在这里插入图片描述

高维数组的选择

import torch
a=torch.arange(4*512*28*28).view(4,512,28,28)
index=np.random.choice(32,5)# 在0到31内随机选5个值
select=torch.index_select(a,1,index=torch.tensor(index,dtype=int))
print(index)
print(a.shape)
print(select.shape)

输出

# 索引序列
[ 1  4 28  3  3]
# 原始数组
torch.Size([4, 512, 28, 28])
# 索引后的数组
torch.Size([4, 5, 28, 28])

dim=1,相当于在第二个维度中进行选择

torch.index_select与直接索引的区别

import torch
a=torch.arange(40).view(2,4,5)
index=torch.tensor([1])
select_1=torch.index_select(a,dim=0,index=index)
select_2=a[0,:,:]
print(select_1)
print(select_2)
print(select_1.shape)
print(select_2.shape)

输出

# 利用torch.index_select,结果
tensor([[[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])
# 直接进行索引,结果
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
# 利用torch.index_select,形状
torch.Size([1, 4, 5])
# 直接进行索引,形状
torch.Size([4, 5])

容易发现,直接进行索引和利用torch.index_select索引最大的区别就在于:直接进行索引数组维数会降低,利用torch.index_select索引数组维数不变,下面的案例更容易理解区别

import torch
a=torch.arange(40).view(2,4,5)
select_1=a[0,:,:]
select_2=a[0,0,:]
select_3=a[0,0,0]
print(select_1)
print(select_2)
print(select_3)
print(select_1.shape)
print(select_2.shape)
print(select_3.shape)

输出

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
tensor([0, 1, 2, 3, 4])
tensor(0)
torch.Size([4, 5])
torch.Size([5])
torch.Size([])

直接索引就是一个逐步逼近的过程,随着给定的数字越多(从给1个到给3个),维数越小(结果从2维到0维),结果范围越精确。而torch.index_select只能沿着一个维度进行搜索查找,相当于对整个数组进行索引。

官方文档

torch.index_select():https://pytorch.org/docs/stable/generated/torch.index_select.html?highlight=index_select

点个赞支持一下吧

  • 24
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值