【python函数】torch.index_select()函数用法解析

   t o r c h . i n d e x _ s e l e c t ( ) {torch.index\_select()} torch.index_select()函数是用来对一个张量进行选择的
  该函数有四个参数:
   i n p u t s {inputs} inputs:输入张量
   d i m {dim} dim:选择张量的维度
   i n d i c e s {indices} indices:所选维度的哪一个,哪一行或哪一列
   o u t {out} out:选择输出到哪一个张量

随机生成size为[3,4,6]的张量
选择第0维度的第1个

import torch

inputs = torch.randn(3,4,6)
dim = 0
indices = torch.tensor(1)

out = torch.index_select(inputs, dim, indices)

print('--------------------------------------------------------------------')
print('inputs大小:', inputs.size())
print('输入inputs:\n', inputs)
print('--------------------------------------------------------------------')
print('out大小:', out.size())
print('输出out:\n', out)
print('--------------------------------------------------------------------')

输出:

--------------------------------------------------------------------
inputs大小: torch.Size([3, 4, 6])
输入inputs:
 tensor([[[-0.8306,  2.2510, -0.0079,  1.3343,  1.9333, -0.1910],
         [ 1.0616, -0.1115,  0.1922,  0.8713,  0.8423,  0.8028],
         [-0.5913, -0.5085, -0.6257,  0.6788, -0.6649,  1.8055],
         [ 1.4420, -0.4334, -0.2076, -1.1872, -2.1117,  0.0666]],

        [[-1.4044,  1.1257, -0.8435, -1.5138,  0.2528, -0.1170],
         [ 0.5917,  1.5663, -0.1840, -0.7563,  0.0655, -3.3442],
         [-0.6688,  0.7822,  0.1355,  0.0038,  2.8646,  0.4860],
         [ 1.3875, -1.0201,  0.8816,  1.9196,  2.2424, -0.7823]],

        [[-0.4744,  0.5082, -0.6402,  0.5121, -1.1556,  0.3753],
         [-0.0148, -0.3519,  1.1065, -0.0597, -0.4442, -0.9122],
         [ 0.2082,  0.2556,  0.9925,  0.5660,  1.6256,  0.4254],
         [ 0.0187,  0.5005,  0.0827, -0.3486, -0.1992, -0.2559]]])
--------------------------------------------------------------------
out大小: torch.Size([1, 4, 6])
输出out:
 tensor([[[-1.4044,  1.1257, -0.8435, -1.5138,  0.2528, -0.1170],
         [ 0.5917,  1.5663, -0.1840, -0.7563,  0.0655, -3.3442],
         [-0.6688,  0.7822,  0.1355,  0.0038,  2.8646,  0.4860],
         [ 1.3875, -1.0201,  0.8816,  1.9196,  2.2424, -0.7823]]])
--------------------------------------------------------------------

选择第1维度的第0行

import torch

inputs = torch.randn(3,4,6)
dim = 1
indices = torch.tensor(0)

out = torch.index_select(inputs, dim, indices)

print('--------------------------------------------------------------------')
print('inputs大小:', inputs.size())
print('输入inputs:\n', inputs)
print('--------------------------------------------------------------------')
print('out大小:', out.size())
print('输出out:\n', out)
print('--------------------------------------------------------------------')

输出:

--------------------------------------------------------------------
inputs大小: torch.Size([3, 4, 6])
输入inputs:
 tensor([[[ 0.6477, -0.3442, -1.1005,  1.1959,  1.6318, -0.7133],
         [-0.5315,  0.1724, -0.9170,  2.7104, -0.6939,  0.6186],
         [-0.5354, -0.4394, -0.6414, -0.5140,  2.7027, -2.8920],
         [-0.3202,  1.2792, -0.0243,  0.6719,  0.7145, -0.3562]],

        [[-0.4330,  0.6837, -0.5198,  0.9839,  0.4061, -2.4644],
         [ 0.3609,  0.5051, -1.1950,  0.8222,  0.7973,  0.5229],
         [ 0.6804,  0.4372,  0.3156, -0.4217, -0.0498, -0.9208],
         [-1.6734,  1.8118,  0.5074,  0.2681,  1.1954, -0.1488]],

        [[ 0.2589, -0.3192,  1.6133,  0.4330, -0.6836,  0.6655],
         [ 1.8668, -1.2365,  1.1614, -0.4567,  0.2217, -2.3281],
         [ 1.2324,  1.3076,  1.1484,  0.8137, -0.9970,  0.0976],
         [ 1.9913, -0.6387,  1.8925, -0.6371, -0.7899, -0.5704]]])
--------------------------------------------------------------------
out大小: torch.Size([3, 1, 6])
输出out:
 tensor([[[ 0.6477, -0.3442, -1.1005,  1.1959,  1.6318, -0.7133]],

        [[-0.4330,  0.6837, -0.5198,  0.9839,  0.4061, -2.4644]],

        [[ 0.2589, -0.3192,  1.6133,  0.4330, -0.6836,  0.6655]]])
--------------------------------------------------------------------

选择第2维度的第3列

import torch

inputs = torch.randn(3,4,6)
dim = 2
indices = torch.tensor(3)
v = torch.zeros(0)
out = torch.index_select(inputs, dim, indices,out=v)

print('--------------------------------------------------------------------')
print('inputs大小:', inputs.size())
print('输入inputs:\n', inputs)
print('--------------------------------------------------------------------')
print('out大小:', out.size())
print('输出out:\n', out)
print('张量v:\n', v)
print('--------------------------------------------------------------------')

输出:

--------------------------------------------------------------------
--------------------------------------------------------------------
inputs大小: torch.Size([3, 4, 6])
输入inputs:
 tensor([[[ 0.9567,  0.0422, -1.5812, -0.0395, -0.0459, -1.0065],
         [ 0.5025,  0.4947, -0.1656,  1.4888, -0.1460, -1.5080],
         [ 0.1737,  0.2020, -0.4196,  0.4830,  0.7096, -0.3554],
         [-1.8761, -0.3909,  2.2152, -2.3158,  1.4835,  0.0928]],

        [[-0.7981, -0.9260,  1.0558, -0.1553, -0.3443,  2.0315],
         [-0.3727, -0.9843,  0.1417, -0.6195, -0.6987, -0.9153],
         [-0.5600, -0.9127,  0.0596,  1.4890, -1.3679, -0.2440],
         [-0.6612,  0.2109, -0.9788,  0.5162,  1.6529,  0.7959]],

        [[ 1.2576,  0.3301,  0.3410,  0.5977, -2.5256,  0.2961],
         [-2.1188, -1.6647,  0.8224,  0.3663,  0.0120, -0.7276],
         [-1.7476, -0.1467,  1.1624,  1.5991,  0.9478, -1.2871],
         [-1.3899, -1.0566, -0.0727,  1.4246, -0.0234,  0.8622]]])
--------------------------------------------------------------------
out大小: torch.Size([3, 4, 1])
输出out:
 tensor([[[-0.0395],
         [ 1.4888],
         [ 0.4830],
         [-2.3158]],

        [[-0.1553],
         [-0.6195],
         [ 1.4890],
         [ 0.5162]],

        [[ 0.5977],
         [ 0.3663],
         [ 1.5991],
         [ 1.4246]]])
张量v:
 tensor([[[-0.0395],
         [ 1.4888],
         [ 0.4830],
         [-2.3158]],

        [[-0.1553],
         [-0.6195],
         [ 1.4890],
         [ 0.5162]],

        [[ 0.5977],
         [ 0.3663],
         [ 1.5991],
         [ 1.4246]]])
--------------------------------------------------------------------
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值