# Example 1: torch.nn.Flatten() with its default arguments.
# All dimensions after the first are merged into one, so a
# (2, 2, 2, 2) tensor becomes (2, 8).
x = torch.ones(2, 2, 2, 2)
F = torch.nn.Flatten()
y = F(x)
print(y)
print(y.shape)
# Output (whitespace condensed):
# tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1., 1., 1.]])
# torch.Size([2, 8])
# Example 2: torch.nn.Flatten(2) -- flattening starts at dimension 2.
# Dimensions 0 and 1 are kept and the remaining ones are merged, so a
# (2, 2, 2, 2) tensor becomes (2, 2, 4).
x = torch.ones(2, 2, 2, 2)
F = torch.nn.Flatten(2)
y = F(x)
print(y)
print(y.shape)
# Output (whitespace condensed):
# tensor([[[1., 1., 1., 1.],
#          [1., 1., 1., 1.]],
#         [[1., 1., 1., 1.],
#          [1., 1., 1., 1.]]])
# torch.Size([2, 2, 4])
# Example 3: torch.nn.Flatten(1, 2) -- only dimensions 1 through 2 are merged.
# Dimension 0 and the last dimension are left untouched, so a
# (2, 2, 2, 2) tensor becomes (2, 4, 2).
x = torch.ones(2, 2, 2, 2)
F = torch.nn.Flatten(1, 2)
y = F(x)
print(y)
print(y.shape)
# Output (whitespace condensed):
# tensor([[[1., 1.], [1., 1.], [1., 1.], [1., 1.]],
#         [[1., 1.], [1., 1.], [1., 1.], [1., 1.]]])
# torch.Size([2, 4, 2])
# Example 4: the functional form, torch.flatten(input, start_dim, end_dim).
t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.shape)
# torch.Size([2, 2, 2])

# With no dims given, the whole tensor collapses to one dimension.
print(torch.flatten(t))
# tensor([1, 2, 3, 4, 5, 6, 7, 8])

# start_dim=1: dimension 0 is kept, everything after it is merged.
print(torch.flatten(t, 1))
# tensor([[1, 2, 3, 4],
#         [5, 6, 7, 8]])

# start_dim=0, end_dim=1: only the first two dimensions are merged.
print(torch.flatten(t, 0, 1).shape)
# torch.Size([4, 2])
# Example 5: flattening a zero-dimensional (scalar) tensor.
# The result is a one-dimensional tensor holding a single element.
t = torch.tensor(1)
print("before flatten:")
print(t)
print(t.shape)
# Output:
# before flatten:
# tensor(1)
# torch.Size([])

print("\n")  # emits blank lines to separate the two sections of output
print("after flatten:")
print(torch.flatten(t))
print(torch.flatten(t).shape)
# Output:
# after flatten:
# tensor([1])
# torch.Size([1])