nn.Flatten():对图像数组进行展平操作 例子 # nn.Flatten():对图像数组进行展平操作 net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)) def init_weights(m): if type(m) == nn.Linear: nn.init.normal_(m.weight, std=0.01) # 给线性层的权重初始化,符合正态分布,标准差为0.01 net.apply(init_weights); Pytorch——nn.Flatten()函数