pytorch中x.view()用法
在pytorch中经常会看到x.view(),它表示将Tensor的维度转变为view指定的维度,有点类似于resize函数
b=torch.Tensor([[[[1,2,3],[4,5,6],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]]])
print(b.size())
(1, 2, 3, 3)
print(b.view(b.size(0),-1))
tensor([[1., 2., 3., 4., 5., 6., 7., 8., 9., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
print(b.view(b.size(0),-1).size())
(1, 18)
b.size(0)表示b中0维度==1,-1是按照原数据自动分配的列数。
a=torch.Tensor([[[1,2,3],[4,5,6]]])
print(a.size())
(1, 2, 3)
print(a.view(6,-1))
tensor([[1.],
[2.],
[3.],
[4.],
[5.],
[6.]])
print(a.view(6,-1).size())
(6, 1)
将a转变成6行1列
print(a.view(-1,6).size())
(1, 6)
或者将a转变成1行6列
在程序里还经常见到view函数后面跟着permute()函数,这个函数是做维度换位的
print(a.view(-1,6).permute(1,0))
tensor([[1.],
[2.],
[3.],
[4.],
[5.],
[6.]])
print(a.view(-1,6).permute(1,0).size())
(6, 1)
加了permute,a就由(1,6)变成(6,1)了。