【Torch API】pytorch 中torch.narrow()、torch.unbind()函数详解

文章介绍了PyTorch中用于张量切片的两个函数:torch.narrow()用于在指定维度上获取一个范围的切片,而torch.unbind()则将张量沿特定维度拆分为多个单独的张量并返回一个元组。这两个函数提供了灵活的数据处理方式。
摘要由CSDN通过智能技术生成

torch.narrow()
PyTorch 中的narrow()函数起到了筛选一定维度上的数据作用。个人感觉与x[begin:end] 相同!

参考官网:torch.narrow()

用法:torch.narrow(input, dim, start, length) → Tensor

返回输入张量的切片操作结果。 输入tensor和返回的tensor共享内存。

参数说明:

  • input (Tensor) – 需切片的张量
  • dim (int) – 切片维度
  • start (int) – 开始的索引
  • length (int) – 切片长度

示例代码:

In [1]: import torch

In [2]: x = torch.randn(3,3)

In [3]: x
Out[3]:
tensor([[ 1.2474,  0.1820, -0.0179],
        [ 0.1388, -1.7373,  0.5934],
        [ 0.2288,  1.1102,  0.6743]])

In [4]: x.narrow(0, 1, 2) # 行切片
Out[4]:
tensor([[ 0.1388, -1.7373,  0.5934],
        [ 0.2288,  1.1102,  0.6743]])

In [5]: x.narrow(1, 1, 2) # 列切片
Out[5]:
tensor([[ 0.1820, -0.0179],
        [-1.7373,  0.5934],
        [ 1.1102,  0.6743]])


torch.unbind()

torch.unbind()移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。

参考官网:torch.unbind()

用法:torch.unbind(input, dim=0) → seq

返回指定维度切片后的元组。

代码示例:

In [6]: x
Out[6]:
tensor([[ 1.2474,  0.1820, -0.0179],
        [ 0.1388, -1.7373,  0.5934],
        [ 0.2288,  1.1102,  0.6743]])

In [7]: torch.unbind(x, 0)
Out[7]:
(tensor([ 1.2474,  0.1820, -0.0179]),
 tensor([ 0.1388, -1.7373,  0.5934]),
 tensor([0.2288, 1.1102, 0.6743]))

In [8]: torch.unbind(x, 1)
Out[8]:
(tensor([1.2474, 0.1388, 0.2288]),
 tensor([ 0.1820, -1.7373,  1.1102]),
 tensor([-0.0179,  0.5934,  0.6743]))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值