pytorch 笔记:index_select

1 基本使用方法

index_select 是 PyTorch 中的一个非常有用的函数,允许从给定的维度中选择指定索引的张量值

torch.index_select(input, dim, index, out=None) -> Tensor
input从中选择数据的源张量
dim从中选择数据的维度
index

一个 1D 张量,包含你想要从 dim 维度中选择的索引

此张量应该是 LongTensor 类型

out

一个可选的参数,用于指定输出张量。

如果没有提供,将创建一个新的张量。

2 举例

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4))
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11],
        [13, 15]], dtype=torch.int32)
'''

3 index_select保存梯度

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4),dtype=torch.float32, requires_grad=True)
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]], requires_grad=True)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4.,  5.,  6.,  7.],
        [12., 13., 14., 15.]], grad_fn=<IndexSelectBackward0>)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1.,  3.],
        [ 5.,  7.],
        [ 9., 11.],
        [13., 15.]], grad_fn=<IndexSelectBackward0>)
'''

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值