t
o
r
c
h
.
i
n
d
e
x
_
s
e
l
e
c
t
(
)
{torch.index\_select()}
torch.index_select()函数是用来对一个张量进行选择的
该函数有四个参数:
i
n
p
u
t
s
{inputs}
inputs:输入张量
d
i
m
{dim}
dim:选择张量的维度
i
n
d
i
c
e
s
{indices}
indices:所选维度的哪一个,哪一行或哪一列
o
u
t
{out}
out:选择输出到哪一个张量
随机生成size为[3,4,6]的张量
选择第0维度的第1个
import torch
inputs = torch.randn(3,4,6)
dim = 0
indices = torch.tensor(1)
out = torch.index_select(inputs, dim, indices)
print('--------------------------------------------------------------------')
print('inputs大小:', inputs.size())
print('输入inputs:\n', inputs)
print('--------------------------------------------------------------------')
print('out大小:', out.size())
print('输出out:\n', out)
print('--------------------------------------------------------------------')
输出:
--------------------------------------------------------------------
inputs大小: torch.Size([3, 4, 6])
输入inputs:
tensor([[[-0.8306, 2.2510, -0.0079, 1.3343, 1.9333, -0.1910],
[ 1.0616, -0.1115, 0.1922, 0.8713, 0.8423, 0.8028],
[-0.5913, -0.5085, -0.6257, 0.6788, -0.6649, 1.8055],
[ 1.4420, -0.4334, -0.2076, -1.1872, -2.1117, 0.0666]],
[[-1.4044, 1.1257, -0.8435, -1.5138, 0.2528, -0.1170],
[ 0.5917, 1.5663, -0.1840, -0.7563, 0.0655, -3.3442],
[-0.6688, 0.7822, 0.1355, 0.0038, 2.8646, 0.4860],
[ 1.3875, -1.0201, 0.8816, 1.9196, 2.2424, -0.7823]],
[[-0.4744, 0.5082, -0.6402, 0.5121, -1.1556, 0.3753],
[-0.0148, -0.3519, 1.1065, -0.0597, -0.4442, -0.9122],
[ 0.2082, 0.2556, 0.9925, 0.5660, 1.6256, 0.4254],
[ 0.0187, 0.5005, 0.0827, -0.3486, -0.1992, -0.2559]]])
--------------------------------------------------------------------
out大小: torch.Size([1, 4, 6])
输出out:
tensor([[[-1.4044, 1.1257, -0.8435, -1.5138, 0.2528, -0.1170],
[ 0.5917, 1.5663, -0.1840, -0.7563, 0.0655, -3.3442],
[-0.6688, 0.7822, 0.1355, 0.0038, 2.8646, 0.4860],
[ 1.3875, -1.0201, 0.8816, 1.9196, 2.2424, -0.7823]]])
--------------------------------------------------------------------
选择第1维度的第0行
import torch
inputs = torch.randn(3,4,6)
dim = 1
indices = torch.tensor(0)
out = torch.index_select(inputs, dim, indices)
print('--------------------------------------------------------------------')
print('inputs大小:', inputs.size())
print('输入inputs:\n', inputs)
print('--------------------------------------------------------------------')
print('out大小:', out.size())
print('输出out:\n', out)
print('--------------------------------------------------------------------')
输出:
--------------------------------------------------------------------
inputs大小: torch.Size([3, 4, 6])
输入inputs:
tensor([[[ 0.6477, -0.3442, -1.1005, 1.1959, 1.6318, -0.7133],
[-0.5315, 0.1724, -0.9170, 2.7104, -0.6939, 0.6186],
[-0.5354, -0.4394, -0.6414, -0.5140, 2.7027, -2.8920],
[-0.3202, 1.2792, -0.0243, 0.6719, 0.7145, -0.3562]],
[[-0.4330, 0.6837, -0.5198, 0.9839, 0.4061, -2.4644],
[ 0.3609, 0.5051, -1.1950, 0.8222, 0.7973, 0.5229],
[ 0.6804, 0.4372, 0.3156, -0.4217, -0.0498, -0.9208],
[-1.6734, 1.8118, 0.5074, 0.2681, 1.1954, -0.1488]],
[[ 0.2589, -0.3192, 1.6133, 0.4330, -0.6836, 0.6655],
[ 1.8668, -1.2365, 1.1614, -0.4567, 0.2217, -2.3281],
[ 1.2324, 1.3076, 1.1484, 0.8137, -0.9970, 0.0976],
[ 1.9913, -0.6387, 1.8925, -0.6371, -0.7899, -0.5704]]])
--------------------------------------------------------------------
out大小: torch.Size([3, 1, 6])
输出out:
tensor([[[ 0.6477, -0.3442, -1.1005, 1.1959, 1.6318, -0.7133]],
[[-0.4330, 0.6837, -0.5198, 0.9839, 0.4061, -2.4644]],
[[ 0.2589, -0.3192, 1.6133, 0.4330, -0.6836, 0.6655]]])
--------------------------------------------------------------------
选择第2维度的第3列
import torch
inputs = torch.randn(3,4,6)
dim = 2
indices = torch.tensor(3)
v = torch.zeros(0)
out = torch.index_select(inputs, dim, indices,out=v)
print('--------------------------------------------------------------------')
print('inputs大小:', inputs.size())
print('输入inputs:\n', inputs)
print('--------------------------------------------------------------------')
print('out大小:', out.size())
print('输出out:\n', out)
print('张量v:\n', v)
print('--------------------------------------------------------------------')
输出:
--------------------------------------------------------------------
--------------------------------------------------------------------
inputs大小: torch.Size([3, 4, 6])
输入inputs:
tensor([[[ 0.9567, 0.0422, -1.5812, -0.0395, -0.0459, -1.0065],
[ 0.5025, 0.4947, -0.1656, 1.4888, -0.1460, -1.5080],
[ 0.1737, 0.2020, -0.4196, 0.4830, 0.7096, -0.3554],
[-1.8761, -0.3909, 2.2152, -2.3158, 1.4835, 0.0928]],
[[-0.7981, -0.9260, 1.0558, -0.1553, -0.3443, 2.0315],
[-0.3727, -0.9843, 0.1417, -0.6195, -0.6987, -0.9153],
[-0.5600, -0.9127, 0.0596, 1.4890, -1.3679, -0.2440],
[-0.6612, 0.2109, -0.9788, 0.5162, 1.6529, 0.7959]],
[[ 1.2576, 0.3301, 0.3410, 0.5977, -2.5256, 0.2961],
[-2.1188, -1.6647, 0.8224, 0.3663, 0.0120, -0.7276],
[-1.7476, -0.1467, 1.1624, 1.5991, 0.9478, -1.2871],
[-1.3899, -1.0566, -0.0727, 1.4246, -0.0234, 0.8622]]])
--------------------------------------------------------------------
out大小: torch.Size([3, 4, 1])
输出out:
tensor([[[-0.0395],
[ 1.4888],
[ 0.4830],
[-2.3158]],
[[-0.1553],
[-0.6195],
[ 1.4890],
[ 0.5162]],
[[ 0.5977],
[ 0.3663],
[ 1.5991],
[ 1.4246]]])
张量v:
tensor([[[-0.0395],
[ 1.4888],
[ 0.4830],
[-2.3158]],
[[-0.1553],
[-0.6195],
[ 1.4890],
[ 0.5162]],
[[ 0.5977],
[ 0.3663],
[ 1.5991],
[ 1.4246]]])
--------------------------------------------------------------------