Selecting a slice of data along a specified dimension
Usage
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)  # dimension 0: indices 0 and 1 (start=0, length=2)
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> torch.narrow(x, 1, 1, 2)  # dimension 1: indices 1 and 2 (start=1, length=2)
tensor([[2, 3],
        [5, 6],
        [8, 9]])
API
torch.narrow(input, dim, start, length) → Tensor
Parameter | Description |
---|---|
input (Tensor) | the tensor to narrow |
dim (int) | the dimension along which to narrow |
start (int) | index along dim of the first element to keep |
length (int) | number of elements to keep along dim; start + length must not exceed the size of that dimension |
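One detail the table does not show: per the official PyTorch documentation, the tensor returned by narrow shares its underlying storage with input, so it is a view rather than a copy. The minimal sketch below demonstrates this; calling .clone() to get an independent copy is one common way to avoid the side effect, not the only one.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y = torch.narrow(x, 1, 1, 2)  # columns 1 and 2, returned as a view of x

# Writing through the view also modifies x, because they share storage
y[0, 0] = 100
print(x[0].tolist())  # [1, 100, 3]

# clone() produces an independent copy; modifying it leaves x unchanged
z = torch.narrow(x, 0, 0, 1).clone()
z.zero_()
print(x[0].tolist())  # [1, 100, 3]
print(z.tolist())     # [[0, 0, 0]]
```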
References:
https://pytorch.org/docs/stable/generated/torch.narrow.html#torch.narrow
https://blog.csdn.net/u011961856/article/details/78696146