torch.narrow()
PyTorch 中的narrow()函数起到了筛选一定维度上的数据作用。个人感觉与x[begin:end] 相同!
参考官网:torch.narrow()
用法:torch.narrow(input, dim, start, length) → Tensor
返回输入张量的切片操作结果。 输入tensor和返回的tensor共享内存。
参数说明:
- input (Tensor) – 需切片的张量
- dim (int) – 切片维度
- start (int) – 开始的索引
- length (int) – 切片长度
示例代码:
In [1]: import torch
In [2]: x = torch.randn(3,3)
In [3]: x
Out[3]:
tensor([[ 1.2474, 0.1820, -0.0179],
[ 0.1388, -1.7373, 0.5934],
[ 0.2288, 1.1102, 0.6743]])
In [4]: x.narrow(0, 1, 2) # 行切片
Out[4]:
tensor([[ 0.1388, -1.7373, 0.5934],
[ 0.2288, 1.1102, 0.6743]])
In [5]: x.narrow(1, 1, 2) # 列切片
Out[5]:
tensor([[ 0.1820, -0.0179],
[-1.7373, 0.5934],
[ 1.1102, 0.6743]])
torch.unbind()
torch.unbind()
移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。
参考官网:torch.unbind()
用法:torch.unbind(input, dim=0) → seq
返回指定维度切片后的元组。
代码示例:
In [6]: x
Out[6]:
tensor([[ 1.2474, 0.1820, -0.0179],
[ 0.1388, -1.7373, 0.5934],
[ 0.2288, 1.1102, 0.6743]])
In [7]: torch.unbind(x, 0)
Out[7]:
(tensor([ 1.2474, 0.1820, -0.0179]),
tensor([ 0.1388, -1.7373, 0.5934]),
tensor([0.2288, 1.1102, 0.6743]))
In [8]: torch.unbind(x, 1)
Out[8]:
(tensor([1.2474, 0.1388, 0.2288]),
tensor([ 0.1820, -1.7373, 1.1102]),
tensor([-0.0179, 0.5934, 0.6743]))