参考链接: Tensor.index_select(dim, index) → Tensor
参考链接: torch.index_select(input, dim, index, out=None) → Tensor
函数功能说明:
torch.index_select(input, dim, index, out=None)
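沿着给定的维度 dim，按照一维索引张量 index（LongTensor）中的下标选取 input 的元素，返回一个新张量。返回张量的维数与 input 相同：第 dim 维的大小等于 index 的长度，其余各维大小保持不变；返回的张量不与 input 共享存储。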
实验代码展示:
Microsoft Windows [版本 10.0.18363.1256]
(c) 2019 Microsoft Corporation。保留所有权利。
C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May 6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DF1130D330>
>>>
>>>
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.2824, -0.3715, 0.9088, -1.7601],
[-0.1806, 2.0937, 1.0406, -1.7651],
[ 1.1216, 0.8440, 0.1783, 0.6859]])
>>>
>>> indices = torch.tensor([0, 2])
>>> indices
tensor([0, 2])
>>>
>>> torch.index_select(x, 0, indices)
tensor([[ 0.2824, -0.3715, 0.9088, -1.7601],
[ 1.1216, 0.8440, 0.1783, 0.6859]])
>>>
>>> torch.index_select(x, 1, indices)
tensor([[ 0.2824, 0.9088],
[-0.1806, 1.0406],
[ 1.1216, 0.1783]])
>>>
>>>
>>>
>>>
>>>
>>>
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DF1130D330>
>>> x = torch.randn(2, 3, 4)
>>> indices
tensor([0, 2])
>>> x
tensor([[[ 0.5816, 2.0060, 1.6013, -0.6379],
[-1.1943, 0.1426, 1.3612, -1.4171],
[ 0.2797, -0.5316, 0.6480, 2.6538]],
[[ 1.1088, 1.2914, -1.4494, -1.7273],
[-0.2370, -1.7016, -0.2565, 1.4568],
[ 0.7152, 1.1784, 0.0806, 0.7787]]])
>>>
>>>
>>> # x = torch.randn(3, 4, 5)
>>>
>>>
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DF1130D330>
>>>
>>> x = torch.randn(3, 4, 5)
>>> indices
tensor([0, 2])
>>> x
tensor([[[ 0.5816, 2.0060, 1.6013, -0.6379, -1.1943],
[ 0.1426, 1.3612, -1.4171, -2.7679, 2.1961],
[ 1.3462, -0.7650, 1.5528, 0.7210, -0.1314],
[ 0.1448, 1.5570, 1.0334, 0.2356, 0.6646]],
[[-0.4996, 0.2791, -0.5900, 0.2527, 0.6736],
[-3.4868, 1.0735, -0.0428, -0.2797, -1.3232],
[-0.8263, -0.4343, -1.0501, 1.5451, 2.1098],
[ 0.1178, -0.4273, -0.4440, -0.4213, -0.9365]],
[[ 1.8106, -0.1812, 1.0847, 0.3907, -1.0521],
[ 0.7310, -0.0536, 0.1632, -1.2770, 0.8514],
[ 1.2402, -1.6003, -2.0225, -0.0334, 0.7078],
[ 0.2945, 0.1468, -0.5430, 0.1792, 0.0895]]])
>>> torch.index_select(x, 0, indices)
tensor([[[ 0.5816, 2.0060, 1.6013, -0.6379, -1.1943],
[ 0.1426, 1.3612, -1.4171, -2.7679, 2.1961],
[ 1.3462, -0.7650, 1.5528, 0.7210, -0.1314],
[ 0.1448, 1.5570, 1.0334, 0.2356, 0.6646]],
[[ 1.8106, -0.1812, 1.0847, 0.3907, -1.0521],
[ 0.7310, -0.0536, 0.1632, -1.2770, 0.8514],
[ 1.2402, -1.6003, -2.0225, -0.0334, 0.7078],
[ 0.2945, 0.1468, -0.5430, 0.1792, 0.0895]]])
>>>
>>> x
tensor([[[ 0.5816, 2.0060, 1.6013, -0.6379, -1.1943],
[ 0.1426, 1.3612, -1.4171, -2.7679, 2.1961],
[ 1.3462, -0.7650, 1.5528, 0.7210, -0.1314],
[ 0.1448, 1.5570, 1.0334, 0.2356, 0.6646]],
[[-0.4996, 0.2791, -0.5900, 0.2527, 0.6736],
[-3.4868, 1.0735, -0.0428, -0.2797, -1.3232],
[-0.8263, -0.4343, -1.0501, 1.5451, 2.1098],
[ 0.1178, -0.4273, -0.4440, -0.4213, -0.9365]],
[[ 1.8106, -0.1812, 1.0847, 0.3907, -1.0521],
[ 0.7310, -0.0536, 0.1632, -1.2770, 0.8514],
[ 1.2402, -1.6003, -2.0225, -0.0334, 0.7078],
[ 0.2945, 0.1468, -0.5430, 0.1792, 0.0895]]])
>>> torch.index_select(x, 1, indices)
tensor([[[ 0.5816, 2.0060, 1.6013, -0.6379, -1.1943],
[ 1.3462, -0.7650, 1.5528, 0.7210, -0.1314]],
[[-0.4996, 0.2791, -0.5900, 0.2527, 0.6736],
[-0.8263, -0.4343, -1.0501, 1.5451, 2.1098]],
[[ 1.8106, -0.1812, 1.0847, 0.3907, -1.0521],
[ 1.2402, -1.6003, -2.0225, -0.0334, 0.7078]]])
>>>
>>> x
tensor([[[ 0.5816, 2.0060, 1.6013, -0.6379, -1.1943],
[ 0.1426, 1.3612, -1.4171, -2.7679, 2.1961],
[ 1.3462, -0.7650, 1.5528, 0.7210, -0.1314],
[ 0.1448, 1.5570, 1.0334, 0.2356, 0.6646]],
[[-0.4996, 0.2791, -0.5900, 0.2527, 0.6736],
[-3.4868, 1.0735, -0.0428, -0.2797, -1.3232],
[-0.8263, -0.4343, -1.0501, 1.5451, 2.1098],
[ 0.1178, -0.4273, -0.4440, -0.4213, -0.9365]],
[[ 1.8106, -0.1812, 1.0847, 0.3907, -1.0521],
[ 0.7310, -0.0536, 0.1632, -1.2770, 0.8514],
[ 1.2402, -1.6003, -2.0225, -0.0334, 0.7078],
[ 0.2945, 0.1468, -0.5430, 0.1792, 0.0895]]])
>>> torch.index_select(x, 2, indices)
tensor([[[ 0.5816, 1.6013],
[ 0.1426, -1.4171],
[ 1.3462, 1.5528],
[ 0.1448, 1.0334]],
[[-0.4996, -0.5900],
[-3.4868, -0.0428],
[-0.8263, -1.0501],
[ 0.1178, -0.4440]],
[[ 1.8106, 1.0847],
[ 0.7310, 0.1632],
[ 1.2402, -2.0225],
[ 0.2945, -0.5430]]])
>>>
>>>
>>>