2、torch.nn.Flatten()的参数实例
# 默认参数
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten()
a1 = F(a)
a的大小:
torch.Size([8, 3, 64, 64])
a1的大小:
torch.Size([8, 12288])
默认将第0维保留下来,其余拍成一维
# 一个参数
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(2)
a1 = F(a)
a的大小:
torch.Size([8, 3, 64, 64])
a1的大小:
torch.Size([8, 3, 4096])
从第二维开始,拍成一维
# 两个参数
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(1,2)
a1 = F(a)
a的大小:
torch.Size([8, 3, 64, 64])
a1的大小:
torch.Size([8, 192, 64])
将第一维到第二维拍成一维,其余不变