torch.index_select()的使用举例

参考链接: index_select(dim, index) → Tensor
参考链接: torch.index_select(input, dim, index, out=None) → Tensor

在这里插入图片描述

函数功能说明:

torch.index_select(input, dim, index, out=None)
根据index在给定的轴dim上,进行元素选择.

实验代码展示:

Microsoft Windows [版本 10.0.18363.1256]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DF1130D330>
>>>
>>>
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]])
>>>
>>> indices = torch.tensor([0, 2])
>>> indices
tensor([0, 2])
>>>
>>> torch.index_select(x, 0, indices)
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [ 1.1216,  0.8440,  0.1783,  0.6859]])
>>>
>>> torch.index_select(x, 1, indices)
tensor([[ 0.2824,  0.9088],
        [-0.1806,  1.0406],
        [ 1.1216,  0.1783]])
>>>
>>>
>>>
>>>
>>>
>>>
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DF1130D330>
>>> x = torch.randn(2, 3, 4)
>>> indices
tensor([0, 2])
>>> x
tensor([[[ 0.5816,  2.0060,  1.6013, -0.6379],
         [-1.1943,  0.1426,  1.3612, -1.4171],
         [ 0.2797, -0.5316,  0.6480,  2.6538]],

        [[ 1.1088,  1.2914, -1.4494, -1.7273],
         [-0.2370, -1.7016, -0.2565,  1.4568],
         [ 0.7152,  1.1784,  0.0806,  0.7787]]])
>>>
>>>
>>> # x = torch.randn(3, 4, 5)
>>>
>>>
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001DF1130D330>
>>>
>>> x = torch.randn(3, 4, 5)
>>> indices
tensor([0, 2])
>>> x
tensor([[[ 0.5816,  2.0060,  1.6013, -0.6379, -1.1943],
         [ 0.1426,  1.3612, -1.4171, -2.7679,  2.1961],
         [ 1.3462, -0.7650,  1.5528,  0.7210, -0.1314],
         [ 0.1448,  1.5570,  1.0334,  0.2356,  0.6646]],

        [[-0.4996,  0.2791, -0.5900,  0.2527,  0.6736],
         [-3.4868,  1.0735, -0.0428, -0.2797, -1.3232],
         [-0.8263, -0.4343, -1.0501,  1.5451,  2.1098],
         [ 0.1178, -0.4273, -0.4440, -0.4213, -0.9365]],

        [[ 1.8106, -0.1812,  1.0847,  0.3907, -1.0521],
         [ 0.7310, -0.0536,  0.1632, -1.2770,  0.8514],
         [ 1.2402, -1.6003, -2.0225, -0.0334,  0.7078],
         [ 0.2945,  0.1468, -0.5430,  0.1792,  0.0895]]])
>>> torch.index_select(x, 0, indices)
tensor([[[ 0.5816,  2.0060,  1.6013, -0.6379, -1.1943],
         [ 0.1426,  1.3612, -1.4171, -2.7679,  2.1961],
         [ 1.3462, -0.7650,  1.5528,  0.7210, -0.1314],
         [ 0.1448,  1.5570,  1.0334,  0.2356,  0.6646]],

        [[ 1.8106, -0.1812,  1.0847,  0.3907, -1.0521],
         [ 0.7310, -0.0536,  0.1632, -1.2770,  0.8514],
         [ 1.2402, -1.6003, -2.0225, -0.0334,  0.7078],
         [ 0.2945,  0.1468, -0.5430,  0.1792,  0.0895]]])
>>>
>>> x
tensor([[[ 0.5816,  2.0060,  1.6013, -0.6379, -1.1943],
         [ 0.1426,  1.3612, -1.4171, -2.7679,  2.1961],
         [ 1.3462, -0.7650,  1.5528,  0.7210, -0.1314],
         [ 0.1448,  1.5570,  1.0334,  0.2356,  0.6646]],

        [[-0.4996,  0.2791, -0.5900,  0.2527,  0.6736],
         [-3.4868,  1.0735, -0.0428, -0.2797, -1.3232],
         [-0.8263, -0.4343, -1.0501,  1.5451,  2.1098],
         [ 0.1178, -0.4273, -0.4440, -0.4213, -0.9365]],

        [[ 1.8106, -0.1812,  1.0847,  0.3907, -1.0521],
         [ 0.7310, -0.0536,  0.1632, -1.2770,  0.8514],
         [ 1.2402, -1.6003, -2.0225, -0.0334,  0.7078],
         [ 0.2945,  0.1468, -0.5430,  0.1792,  0.0895]]])
>>> torch.index_select(x, 1, indices)
tensor([[[ 0.5816,  2.0060,  1.6013, -0.6379, -1.1943],
         [ 1.3462, -0.7650,  1.5528,  0.7210, -0.1314]],

        [[-0.4996,  0.2791, -0.5900,  0.2527,  0.6736],
         [-0.8263, -0.4343, -1.0501,  1.5451,  2.1098]],

        [[ 1.8106, -0.1812,  1.0847,  0.3907, -1.0521],
         [ 1.2402, -1.6003, -2.0225, -0.0334,  0.7078]]])
>>>
>>> x
tensor([[[ 0.5816,  2.0060,  1.6013, -0.6379, -1.1943],
         [ 0.1426,  1.3612, -1.4171, -2.7679,  2.1961],
         [ 1.3462, -0.7650,  1.5528,  0.7210, -0.1314],
         [ 0.1448,  1.5570,  1.0334,  0.2356,  0.6646]],

        [[-0.4996,  0.2791, -0.5900,  0.2527,  0.6736],
         [-3.4868,  1.0735, -0.0428, -0.2797, -1.3232],
         [-0.8263, -0.4343, -1.0501,  1.5451,  2.1098],
         [ 0.1178, -0.4273, -0.4440, -0.4213, -0.9365]],

        [[ 1.8106, -0.1812,  1.0847,  0.3907, -1.0521],
         [ 0.7310, -0.0536,  0.1632, -1.2770,  0.8514],
         [ 1.2402, -1.6003, -2.0225, -0.0334,  0.7078],
         [ 0.2945,  0.1468, -0.5430,  0.1792,  0.0895]]])
>>> torch.index_select(x, 2, indices)
tensor([[[ 0.5816,  1.6013],
         [ 0.1426, -1.4171],
         [ 1.3462,  1.5528],
         [ 0.1448,  1.0334]],

        [[-0.4996, -0.5900],
         [-3.4868, -0.0428],
         [-0.8263, -1.0501],
         [ 0.1178, -0.4440]],

        [[ 1.8106,  1.0847],
         [ 0.7310,  0.1632],
         [ 1.2402, -2.0225],
         [ 0.2945, -0.5430]]])
>>>
>>>
>>>
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值