pytorch中的view()函数就是用来改变tensor的形状的,将多维度的Tensor展平成一维,
例如将2行3列的tensor变为1行6列,其中-1表示会自适应的调整剩余的维度
1.首先用代码解释x.size(0)里面的0是干嘛的
import torch
a = torch.randn(2,3)
print(a)
print(a.size())
print(a.size(0))
输出:
tensor([[-0.7595, -0.2599, -1.3692],
[ 0.8177, 1.1867, 0.0779]])
torch.Size([2, 3])
2
2.展平
x = a.view(1,-1)
print(x)
print(x.size())
输出:
tensor([[-0.7595, -0.2599, -1.3692, 0.8177, 1.1867, 0.0779]])
torch.Size([1, 6])
3.变成三维啦
x = a.view(a.size(0),1,-1)
print(x)
print(x.size())
输出:
tensor([[[-0.7595, -0.2599, -1.3692]],
[[ 0.8177, 1.1867, 0.0779]]])
torch.Size([2, 1, 3])