import torch
a = torch.range(1,30)
print(a)
b = a.view(2,3,5)
print(b)
print(b.view(b.size(0),-1))
print(b.view(b.size(1),-1))
print(b.view(b.size(2),-1))
获得:
a:tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,
29., 30.])
b:tensor([[[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15.]],
[[16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25.],
[26., 27., 28., 29., 30.]]])
b.size(0):
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15.],
[16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
30.]])
b.size(1):
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25., 26., 27., 28., 29., 30.]])
b.size(2):
tensor([[ 1., 2., 3., 4., 5., 6.],
[ 7., 8., 9., 10., 11., 12.],
[13., 14., 15., 16., 17., 18.],
[19., 20., 21., 22., 23., 24.],
[25., 26., 27., 28., 29., 30.]])
b是一个2组3行5列,
b.size(0)就是留下2组,后面3行5列拉直成15个数,行成2行15列;
b.size(1)就是留下3行,2组5列拉成10个数,行成3行10列;
b.size(2)就是留下5列,2组3行拉成6个数,行成5行6列