torch.unbind(input, dim=0)
函数的作用:
- 删除一个 tensor 的维度。可以理解为降维。
- 返回一个沿给定维度的所有切片的元组。
参数:
torch.unbind(input, dim=0):
input 为输入的 tensor。
dim 为需要移除的维度。
举例:
例子1: dim = 0
import torch
a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(torch.unbind(a,0))
输出:
(tensor([1., 2., 3.]), tensor([4., 5., 6.]), tensor([7., 8., 9.]))
例子2: dim = 1
import torch
a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(torch.unbind(a,1))
输出:
(tensor([1., 4., 7.]), tensor([2., 5., 8.]), tensor([3., 6., 9.]))
也可使用如下形式,与上面等价:
a.unbind(0)
import torch
a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(a.unbind(0))
输出:
(tensor([1., 2., 3.]), tensor([4., 5., 6.]), tensor([7., 8., 9.]))
参考: pytorch 官方文档
https://pytorch.org/docs/stable/generated/torch.unbind.html?highlight=unbind#torch.unbind