torch.index_select()——数组索引
torch.index_select(input, dim, index, *, out=None) → Tensor
功能:选择根据给定的index
和dim
在input
中选择张量数据,相当于更高级的索引功能。
输入:
input
:需要索引的张量数组dim
:索引维度(沿dim
维度进行索引)index
:索引值,可以是单个数字、也可以是一个序列(一维序列)
注意:
- 返回的张量数组与原始的张量数组具有相同的维数,这里与直接进行索引有区别,具体见案例代码;
dim
维度的尺寸大小与index
的长度相同,其他尺寸大小与原始张量中的尺寸相同index
:维数必须小于等于1
案例代码
当index索引为单个数字时
import torch
a=torch.arange(40).view(2,4,5)
index=torch.tensor([1])
select_1=torch.index_select(a,dim=0,index=index)
select_2=torch.index_select(a,dim=1,index=index)
select_3=torch.index_select(a,dim=2,index=index)
print(select_1)
print(select_2)
print(select_3)
print(a.shape)
print(select_1.shape)
print(select_2.shape)
print(select_3.shape)
# dim=0
tensor([[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]]])
# dim=1
tensor([[[ 5, 6, 7, 8, 9]],
[[25, 26, 27, 28, 29]]])
# dim=2
tensor([[[ 1],
[ 6],
[11],
[16]],
[[21],
[26],
[31],
[36]]])
# 索引后的数组尺寸,除了dim部分,其他和原来大小一样
# 原始
torch.Size([2, 4, 5])
# dim=0
torch.Size([1, 4, 5])
# dim=1
torch.Size([2, 1, 5])
# dim=2
torch.Size([2, 4, 1])
当索引为列表时
import torch
a=torch.arange(40).view(2,4,5)
index=torch.tensor([1,3])
select_1=torch.index_select(a,dim=1,index=index)
select_2=torch.index_select(a,dim=2,index=index)
print(select_1)
print(select_2)
# dim=1
tensor([[[ 5, 6, 7, 8, 9],
[15, 16, 17, 18, 19]],
[[25, 26, 27, 28, 29],
[35, 36, 37, 38, 39]]])
# dim=2
tensor([[[ 1, 3],
[ 6, 8],
[11, 13],
[16, 18]],
[[21, 23],
[26, 28],
[31, 33],
[36, 38]]])
高维数组的选择
import torch
a=torch.arange(4*512*28*28).view(4,512,28,28)
index=np.random.choice(32,5)# 在0到31内随机选5个值
select=torch.index_select(a,1,index=torch.tensor(index,dtype=int))
print(index)
print(a.shape)
print(select.shape)
# 索引序列
[ 1 4 28 3 3]
# 原始数组
torch.Size([4, 512, 28, 28])
# 索引后的数组
torch.Size([4, 5, 28, 28])
torch.index_select与直接索引的区别
import torch
a=torch.arange(40).view(2,4,5)
index=torch.tensor([1])
select_1=torch.index_select(a,dim=0,index=index)
select_2=a[0,:,:]
print(select_1)
print(select_2)
print(select_1.shape)
print(select_2.shape)
# 利用torch.index_select,结果
tensor([[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]]])
# 直接进行索引,结果
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
# 利用torch.index_select,形状
torch.Size([1, 4, 5])
# 直接进行索引,形状
torch.Size([4, 5])
容易发现,直接进行索引和利用torch.index_select索引最大的区别就在于:直接进行索引数组维数会降低,利用torch.index_select索引数组维数不变,下面的案例更容易理解区别
import torch
a=torch.arange(40).view(2,4,5)
select_1=a[0,:,:]
select_2=a[0,0,:]
select_3=a[0,0,0]
print(select_1)
print(select_2)
print(select_3)
print(select_1.shape)
print(select_2.shape)
print(select_3.shape)
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
tensor([0, 1, 2, 3, 4])
tensor(0)
torch.Size([4, 5])
torch.Size([5])
torch.Size([])
直接索引就是一个逐步逼近的过程,随着给定的数字越多(从给1个到给3个),维数越小(结果从2维到0维),结果范围越精确。而torch.index_select只能沿着一个维度进行搜索查找,相当于对整个数组进行索引。
官方文档
torch.index_select():torch.index_select — PyTorch 1.13 documentation