(1) 默认的dim不同,torch.flatten()默认的dim=0,而nn.Flatten()默认的dim=1,例如输入数据的尺寸是[3,1,4,4],经过torch.flatten()展开后的尺寸变为[48],而经过nn.Flatten()后得到的结果是[3, 16];
(2) nn.Flatten是一个类,而torch.flatten()则是一个函数。
import torch
from torch import nn
x =torch.randn(6,3,4,5)
y = nn.Flatten().forward(x)
print(y.size())
torch.flatten(x).shape