I. Weight Initialization
1. Vanishing and Exploding Gradients
To see why layer outputs explode, track how the variance propagates through a linear layer. For independent random variables $X$ and $Y$:
$E(XY)=E(X)E(Y)$
$D(X)=E(X^2)-E(X)^2$
$D(X+Y)=D(X)+D(Y)$
$\Rightarrow D(XY)=D(X)D(Y)+D(X)E(Y)^2+D(Y)E(X)^2=D(X)D(Y)$ (the last equality holds when $E(X)=E(Y)=0$)
For the first unit of a hidden layer with $n$ inputs, where inputs and weights are zero-mean with unit variance:
$H_{11}=\sum_{i=0}^{n} X_i \cdot W_{1i}$
$\Rightarrow D(H_{11})=\sum_{i=0}^{n} D(X_i)\cdot D(W_{1i})=n\cdot 1\cdot 1=n$
$std(H_{11})=\sqrt{n}$
If we instead want to keep $D(H_{11})=n\,D(X)\,D(W)=1$, then
$\Rightarrow D(W)=\frac{1}{n}$, i.e. $std(W)=\sqrt{\frac{1}{n}}$
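As a quick numerical sanity check of this derivation (a minimal sketch; the sample count is arbitrary), the variance of $H_{11}=\sum_i X_i W_{1i}$ with $n=256$ independent standard-normal inputs and weights comes out close to $n$:

```python
import torch

torch.manual_seed(0)
n = 256
# 100k Monte-Carlo samples of H = sum_i X_i * W_i with X_i, W_i ~ N(0, 1)
X = torch.randn(100000, n)
W = torch.randn(100000, n)
H = (X * W).sum(dim=1)
print(H.var().item())   # close to n = 256
print(H.std().item())   # close to sqrt(n) = 16
```

The full experiment below stacks 100 such layers, so the std is multiplied by roughly $\sqrt{n}$ at every layer.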
```python
import os
import torch
import random
import numpy as np
import torch.nn as nn
from common_tools import set_seed

set_seed(3)  # set the random seed

class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            # x = torch.relu(x)
            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break
        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, std=1)

layer_nums = 100
neural_nums = 256
batch_size = 16

net = MLP(neural_nums, layer_nums)
net.initialize()

inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1
output = net(inputs)
print(output)
```
With W initialized from a standard normal distribution (mean 0, std 1), the exploding behavior shown below appears. As expected, the std grows by a factor of roughly $\sqrt{256}=16$ per layer. Setting the std of W to np.sqrt(1/self.neural_num) keeps the outputs in a normal range.
```
layer:0, std:16.0981502532959
layer:1, std:253.29345703125
layer:2, std:3982.99951171875
...
layer:30, std:2.2885405881461517e+37
layer:31, std:nan
output is nan in 31 layers
tensor([[ 4.9907e+37,        -inf,         inf,  ...,         inf,
                -inf,         inf],
        [       -inf,         inf,  2.1733e+38,  ...,  9.1766e+37,
         -4.5777e+37,  3.3680e+37],
        [ 1.4215e+38,        -inf,         inf,  ...,        -inf,
                 inf,         inf],
        ...,
        [-9.2355e+37, -9.9121e+37, -3.7809e+37,  ...,  4.6074e+37,
          2.2305e+36,  1.2982e+38],
        [       -inf,         inf,        -inf,  ...,        -inf,
         -2.2394e+38,  2.0295e+36],
        [       -inf,         inf,  2.1518e+38,  ...,        -inf,
          1.6132e+38,        -inf]], grad_fn=<MmBackward>)
```
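As noted above, scaling the weight std to np.sqrt(1/self.neural_num) keeps each layer's output std near 1. A minimal sketch of the corresponding change to `initialize` in the same MLP class:

```python
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # D(W) = 1/n -> std(W) = sqrt(1/n), so D(H) = n * D(X) * D(W) = 1
                nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))
```

With this initialization the printed per-layer std should stay around 1 instead of growing by a factor of about 16 at every layer.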
2. Initialization for Different Activation Functions
Different activation functions call for different standard deviations when initializing W, so that the data scale at each layer stays in an appropriate range (typically a variance of about 1); see the PyTorch sketch after the list below.
- Sigmoid, tanh: Xavier initialization
- ReLU (and its variants): Kaiming initialization, $D(W)=\frac{2}{n_i}$; for a variant with negative slope $a$ (e.g. Leaky ReLU), $D(W)=\frac{2}{(1+a^2)\cdot n_i}$
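These rules correspond to PyTorch's built-in initializers; a minimal sketch, where the layer size and the negative-slope value are illustrative:

```python
import torch.nn as nn

linear = nn.Linear(256, 256, bias=False)

# Xavier: suited to saturating activations such as sigmoid / tanh
nn.init.xavier_normal_(linear.weight, gain=nn.init.calculate_gain('tanh'))

# Kaiming: suited to ReLU, D(W) = 2 / n_i
nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')

# Kaiming for Leaky ReLU with negative slope a: D(W) = 2 / ((1 + a^2) * n_i)
nn.init.kaiming_normal_(linear.weight, a=0.1, nonlinearity='leaky_relu')
```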