一、说明
torch.init.normal_(tensor,mean=,std=) ,mean:均值,std:正态分布的标准差。
torch.init.normal_:给tensor初始化,一般是给网络中参数weight初始化,初始化参数值符合正态分布。
二、例子
# nn.Flatten():对图像数组进行展平操作
net = nn.Sequential(nn.Flatten(),
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10))
def init_weights(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, std=0.01) # 给线性层的权重初始化,符合正态分布,标准差为0.01
net.apply(init_weights);