文章目录
1. 背景
我们在训练神经网络的过程中,需要经常对神经网络进行随机初始化。但当神经网络不复杂的时候,我们可以不用太关心参数初始化。然而对于深度神经网络来说,初始化方案对于训练的收敛性起到至关作用。糟糕的初始化会让发生梯度爆炸和梯度消失。
2. 梯度消失&梯度爆炸
假设我们有 d 层的深度神经网络, t 表示层数。
h
t
=
f
t
(
h
t
−
1
)
(1)
h^t=f_t(h^{t-1})\tag{1}
ht=ft(ht−1)(1)
y
=
l
⋅
f
d
⋅
.
.
.
⋅
f
1
(
x
)
(2)
y=l·f_d·...·f_1(x)\tag{2}
y=l⋅fd⋅...⋅f1(x)(2)
那么我们可以计算损失值 l 关于权重参数 w 的导数如下:
∂
l
∂
w
t
=
∂
l
∂
h
d
⋅
∂
h
d
∂
h
d
−
1
.
.
.
∂
h
t
+
1
∂
h
t
⏟
d
−
t
次
矩
阵
乘
法
⋅
∂
h
t
∂
w
t
(3)
\frac{\partial l}{\partial w^t}=\frac{\partial l}{\partial h^d}·\underbrace{\frac{\partial h^d}{\partial h^{d-1}}...\frac{\partial h^{t+1}}{\partial h^{t}}}_{d-t次矩阵乘法}·\frac{\partial h^t}{\partial w^t}\tag{3}
∂wt∂l=∂hd∂l⋅d−t次矩阵乘法
∂hd−1∂hd...∂ht∂ht+1⋅∂wt∂ht(3)
因为要进行 d-t 次矩阵乘法,那么如果每一个梯度值为 m 。
1.
5
100
≈
4
×
1
0
17
→
梯
度
爆
炸
(4)
1.5^{100}≈4×10^{17}\rightarrow 梯度爆炸\tag{4}
1.5100≈4×1017→梯度爆炸(4)
0.
8
100
≈
2
×
1
0
−
10
→
梯
度
消
失
(5)
0.8^{100}≈2×10^{-10}\rightarrow 梯度消失\tag{5}
0.8100≈2×10−10→梯度消失(5)
- 梯度爆炸:导致了梯度超过计算机值的范围,造成上溢
- 梯度消失:导致了梯度十分的小,导致神经网络无法更新相关参数
为了解决上述梯度爆炸和梯度消失问题,我们希望在每层梯度都在合理范围内。 - MLP多层感知机举例
假设第 t 层的函数如下:
f t ( h t − 1 ) = σ ( w t h t − 1 ) ; σ 是 激 活 函 数 (6) f_t(h^{t-1})=\sigma(w^th^{t-1});\sigma是激活函数\tag{6} ft(ht−1)=σ(wtht−1);σ是激活函数(6)
对权重 w t w_t wt求导可得:
∂ h t ∂ h t − 1 = d i a g ( σ ′ ( w t h t − 1 ) ) ( w t ) T ; σ ′ 是 σ 的 导 数 函 数 (7) \frac{\partial h^t}{\partial h^{t-1}}=diag(\sigma'(w^th^{t-1}))(w^t)^T;\sigma'是\sigma的导数函数\tag{7} ∂ht−1∂ht=diag(σ′(wtht−1))(wt)T;σ′是σ的导数函数(7)
∏ i = t d − 1 ∂ h t − 1 ∂ h i = ∏ i = t d − 1 d i a g ( σ ′ ( w t h t − 1 ) ) ( w t ) T (8) \prod_{i=t}^{d-1}\frac{\partial h^{t-1}}{\partial h^i}=\prod_{i=t}^{d-1}diag(\sigma'(w^th^{t-1}))(w^t)^T\tag{8} i=t∏d−1∂hi∂ht−1=i=t∏d−1diag(σ′(wtht−1))(wt)T(8) - 假设使用ReLU作为激活函数,那么可以得到:
σ ( x ) = m a x ( 0 , x ) (9) \sigma(x)=max(0,x)\tag{9} σ(x)=max(0,x)(9)
σ ′ ( x ) = { 1 , i f x > 0 0 ; o t h e r w i s e (10) \sigma'(x)=\left\{\begin{array}{l} 1 ,\qquad if \quad x>0 \\ 0 ;\qquad otherwise \end{array}\right.\tag{10} σ′(x)={1,ifx>00;otherwise(10)
- 注:当我们的值大于零时候, ∏ i = t d − 1 ∂ h t − 1 ∂ h i \prod_{i=t}^{d-1}\frac{\partial h^{t-1}}{\partial h^i} ∏i=td−1∂hi∂ht−1中的一些元素的值就由 ∏ i = t d − 1 ( w i ) T \prod_{i=t}^{d-1}(w^i)^T ∏i=td−1(wi)T来决定,如果 d-t很大,那么这个连乘值就非常的大,从而导致梯度爆炸或者梯度消失。
3. 模型初始化
为了解决梯度爆炸和梯度消失,我们希望对找到一个合理的区间进行初始化,常见思路如下:
- 目标:让梯度值在合理的范围内,例如[le-6,le3]
- 将乘法变加法: ResNet,LSTM
- 归一化:梯度归一化,梯度裁剪
- 合理的权重初始和激活函数 <重点>
3.1 目标
为了解决上述问题,我们希望是:让每层的方差是一个常数
- 将每层的输出和梯度都看作是随机变量
- 让它们的均值和方差都保持一致
- 正向期望和方差:
E [ h i t ] = 0 ; V a r [ h i t ] = a (11) E[h^t_i]=0;\quad Var[h_i^t]=a\tag{11} E[hit]=0;Var[hit]=a(11) - 反向期望和方差:
E [ ∂ l ∂ h i t ] = 0 ; V a r [ ∂ l ∂ h i t ] = b ; ∀ i , t (12) E[\frac{\partial l}{\partial h^t_i}]=0;\quad Var[\frac{\partial l}{\partial h^t_i}]=b;\quad \forall i,t\tag{12} E[∂hit∂l]=0;Var[∂hit∂l]=b;∀i,t(12) - 注:其中 a,b 都是常数
- 权重初始化
在合理值区间里随机初始参数
训练开始的时候更容易有数值不稳定。常见思路如下: a .远离最优解的地方损失函数表面可能很复杂;b.最优解附近表面会比较平;c.使用N(0,0.01)来初始可能对小玩过没问题,但不能保证深度神经网络 - 我们以 MLP 为例来讲解下:
假设<1>
权重服从独立同分布,即
w
i
,
j
t
w_{i,j}^t
wi,jt是 i.i.d,且假设均值为 0 ,方差为
γ
t
\gamma_t
γt,且
h
i
t
−
1
h_i^{t-1}
hit−1独立于
w
i
,
j
t
w_{i,j}^t
wi,jt
E
[
w
i
,
j
t
]
=
0
;
V
a
r
[
w
i
,
j
t
]
=
γ
t
(13)
E[w_{i,j}^t]=0;\quad Var[w_{i,j}^t]=\gamma_t\tag{13}
E[wi,jt]=0;Var[wi,jt]=γt(13)
假设<2>
假设这个 MLP 没有激活函数,满足
h
t
=
w
t
h
t
−
1
h^t=w^th^{t-1}
ht=wtht−1,这里
w
t
∈
R
n
t
×
n
t
−
1
w^t \in R^{n_t×n_{t-1}}
wt∈Rnt×nt−1
E
[
h
i
t
]
=
E
[
∑
j
w
i
,
j
t
h
j
t
−
1
]
=
∑
j
E
[
w
i
,
j
t
]
E
[
h
j
t
−
1
]
=
0
(14)
E[h_i^t]=E[\sum_jw^t_{i,j}h_j^{t-1}]=\sum_jE[w^t_{i,j}]E[h_j^{t-1}]=0\tag{14}
E[hit]=E[j∑wi,jthjt−1]=j∑E[wi,jt]E[hjt−1]=0(14)
- 正向方差计算:
V a r [ h i t ] = E [ ( h i t ) 2 ] − E [ h i t ] 2 = E [ ( h i t ) 2 ] − 0 = E [ ( ∑ j w i , j t h j t − 1 ) 2 ] (15) Var[h_i^t]=E[(h_i^t)^2]-E[h_i^t]^2=E[(h_i^t)^2]-0=E[(\sum_jw^t_{i,j}h_j^{t-1})^2]\tag{15} Var[hit]=E[(hit)2]−E[hit]2=E[(hit)2]−0=E[(j∑wi,jthjt−1)2](15)
展开上式可得:
= E [ ∑ j ( w i , j t ) 2 ( h j t − 1 ) 2 + ∑ j ≠ k w i , j t w i , k t h j t − 1 h k t − 1 ] (16) =E[\sum_j(w_{i,j}^t)^2(h_j^{t-1})^2+\sum_{j≠k}w_{i,j}^tw_{i,k}^th_j^{t-1}h_k^{t-1}]\tag{16} =E[j∑(wi,jt)2(hjt−1)2+j=k∑wi,jtwi,kthjt−1hkt−1](16)
因为 E [ h i t ] = 0 E[h_i^t]=0 E[hit]=0
= E [ ∑ j ( w i , j t ) 2 ( h j t − 1 ) 2 ] = ∑ j E [ ( w i , j t ) 2 ] E [ ( h j t − 1 ) 2 ] (17) =E[\sum_j(w_{i,j}^t)^2(h_j^{t-1})^2]=\sum_jE[(w_{i,j}^t)^2]E[(h_j^{t-1})^2]\tag{17} =E[j∑(wi,jt)2(hjt−1)2]=j∑E[(wi,jt)2]E[(hjt−1)2](17)
因为 V a r [ h i t ] = E [ ( h i t ) 2 ] − E [ h i t ] 2 = E [ ( h i t ) 2 ] − 0 = E [ ( h i t ) 2 ] Var[h_i^t]=E[(h_i^t)^2]-E[h_i^t]^2=E[(h_i^t)^2]-0=E[(h_i^t)^2] Var[hit]=E[(hit)2]−E[hit]2=E[(hit)2]−0=E[(hit)2]
所以 E [ ( w i , j t ) 2 ] = V a r [ w i , j t ] , E [ ( h j t − 1 ) 2 ] = V a r [ h j t − 1 ] E[(w_{i,j}^t)^2]=Var[w_{i,j}^t],E[(h_j^{t-1})^2]=Var[h_j^{t-1}] E[(wi,jt)2]=Var[wi,jt],E[(hjt−1)2]=Var[hjt−1]
V a r [ h i t ] = E [ ∑ j ( w i , j t ) 2 ( h j t − 1 ) 2 ] = ∑ j V a r [ w i , j t ] V a r [ h j t − 1 ] (18) Var[h_i^t]=E[\sum_j(w_{i,j}^t)^2(h_j^{t-1})^2]=\sum_jVar[w_{i,j}^t]Var[h_j^{t-1}]\tag{18} Var[hit]=E[j∑(wi,jt)2(hjt−1)2]=j∑Var[wi,jt]Var[hjt−1](18)
因为 V a r [ w i , j t ] = n t − 1 γ t Var[w_{i,j}^t]=n_{t-1}\gamma_t Var[wi,jt]=nt−1γt
V a r [ h i t ] = ∑ j V a r [ w i , j t ] V a r [ h j t − 1 ] = n t − 1 γ t V a r [ h j t − 1 ] (19) Var[h_i^t]=\sum_jVar[w_{i,j}^t]Var[h_j^{t-1}]=n_{t-1}\gamma_tVar[h_j^{t-1}]\tag{19} Var[hit]=j∑Var[wi,jt]Var[hjt−1]=nt−1γtVar[hjt−1](19)
这样我们就得到了递推公式:
V a r [ h i t ] = n t − 1 γ t V a r [ h j t − 1 ] (20) Var[h_i^t]=n_{t-1}\gamma_tVar[h_j^{t-1}]\tag{20} Var[hit]=nt−1γtVar[hjt−1](20)
那么为了保证我们的数据在训练过程中满足方差不变,那么我们只需要满足如下:
n t − 1 γ t = 1 (21) n_{t-1}\gamma_t=1\tag{21} nt−1γt=1(21) - 反向方差计算:
∂ l ∂ h t − 1 = ∂ l ∂ h t w t (22) \frac{\partial l}{\partial h^{t-1}}=\frac{\partial l}{\partial h^{t}}w^t\tag{22} ∂ht−1∂l=∂ht∂lwt(22)
( ∂ l ∂ h t − 1 ) T = [ w t ] T ( ∂ l ∂ h t ) T (23) (\frac{\partial l}{\partial h^{t-1}})^T=[w^t]^T(\frac{\partial l}{\partial h^{t}})^T\tag{23} (∂ht−1∂l)T=[wt]T(∂ht∂l)T(23)
因为我们假设了期望为 0,方差为常数 γ t \gamma_t γt
E [ ∂ l ∂ h i t − 1 ] = 0 (24) E[\frac{\partial l}{\partial h^{t-1}_i}]=0\tag{24} E[∂hit−1∂l]=0(24)
V a r [ ∂ l ∂ h i t − 1 ] = n t γ t V a r [ ∂ l ∂ h j t ] (25) Var[\frac{\partial l}{\partial h^{t-1}_i}]=n_t\gamma_tVar[\frac{\partial l}{\partial h^{t}_j}]\tag{25} Var[∂hit−1∂l]=ntγtVar[∂hjt∂l](25)
以上为迭代公式,为了保证整体的方差不变,需要满足如下:
n t γ t = 1 (26) n_t\gamma_t=1\tag{26} ntγt=1(26)
3.2 分析
我们已经通过正向和反向运算可以得出,需要满足两个条件
n
t
−
1
γ
t
=
1
(27)
n_{t-1}\gamma_t=1\tag{27}
nt−1γt=1(27)
n
t
γ
t
=
1
(28)
n_t\gamma_t=1\tag{28}
ntγt=1(28)
- 注: n t − 1 n_{t-1} nt−1是第 t 层输入的维度; n t n_t nt是第 t 层输出的维度, γ t \gamma_t γt表示第 t 层权重的方差,除非输入与输出相同,否侧无法满足上述条件,为了解决上述问题,我们引入了Xavier。
3.3 Xavier 初始化
将 公式 <27>,<28>相加后可得
(
n
t
−
1
+
n
t
)
γ
t
=
2
(29)
(n_{t-1}+n_t)\gamma_t=2\tag{29}
(nt−1+nt)γt=2(29)
γ
t
=
2
n
t
−
1
+
n
t
(30)
\gamma_t=\frac{2}{n_{t-1}+n_t}\tag{30}
γt=nt−1+nt2(30)
也就是说在给定输入输出维度时,我们希望权重满足期望为 0, 方差为
γ
t
=
2
n
t
−
1
+
n
t
\gamma_t=\frac{2}{n_{t-1}+n_t}
γt=nt−1+nt2
- 对第 t 层的权重层进行初始化权重时,初始化的常见分布如下:
正太分布:
X ∼ N ( 0 , 2 ( n t − 1 + n t ) ) (31) X\sim N(0,\sqrt{\frac{2}{(n_{t-1}+n_t)}})\tag{31} X∼N(0,(nt−1+nt)2)(31)
均匀分布:
X ∼ U ( − 6 ( n t − 1 + n t ) , 6 ( n t − 1 + n t ) ) (32) X\sim U(-\sqrt{\frac{6}{(n_{t-1}+n_t)}},\sqrt{\frac{6}{(n_{t-1}+n_t)}})\tag{32} X∼U(−(nt−1+nt)6,(nt−1+nt)6)(32) - 适配权重形状变换,特别是 n t n_t nt
4. 激活函数
假设这个 MLP 有激活函数时,分布满足期望为 0 , 方差为常数 γ t \gamma_t γt ,并且此激活函数为线性的。我们知道一般是不会选择线性激活函数。我们这里暂且这样假设,通过后续分析此合理性。
- 正向期望方差计算:
σ ( x ) = a x + β (33) \sigma(x)=ax+\beta\tag{33} σ(x)=ax+β(33)
h ′ = w t h t − 1 (34) h'=w^th^{t-1}\tag{34} h′=wtht−1(34)
h t = σ ( h ′ ) (35) h^t=\sigma(h')\tag{35} ht=σ(h′)(35)
那么我们来计算期望和方差:
E [ h i t ] = E [ a h t ′ + β ] = a E [ h t ′ ] + β = β (36) E[h_i^t]=E[ah_t'+\beta]=aE[h_t']+\beta=\beta\tag{36} E[hit]=E[aht′+β]=aE[ht′]+β=β(36)
V a r [ h i t ] = E [ ( h i t ) 2 ] − E 2 [ h i t ] = E [ ( a h t ′ + β ) 2 ] − β 2 (37) Var[h_i^t]=E[(h_i^t)^2]-E^2[h_i^t]=E[(ah_t'+\beta)^2]-{\beta}^2\tag{37} Var[hit]=E[(hit)2]−E2[hit]=E[(aht′+β)2]−β2(37)
V a r [ h i t ] = E [ a 2 ( h i ′ ) 2 + 2 a β h i ′ + β 2 ) ] − β 2 = a 2 V a r [ h i ′ ] (38) Var[h_i^t]=E[a^2(h_i')^2+2a\beta h_i'+\beta^2)]-{\beta}^2=a^2Var[h_i']\tag{38} Var[hit]=E[a2(hi′)2+2aβhi′+β2)]−β2=a2Var[hi′](38)
为了保证在经过第 t 层的后还是满足前后的两个分布是期望为零,方差不变,通过迭代公式<36>,<38>那么需要使得
β = 0 (39) \beta=0\tag{39} β=0(39)
a = 1 (40) a=1\tag{40} a=1(40)
所以线性激活函数应该为:
σ ( x ) = x (40) \sigma(x)=x\tag{40} σ(x)=x(40) - 反向期望方差计算:
线性激活函数:
σ ( x ) = a x + β (41) \sigma(x)=ax+\beta\tag{41} σ(x)=ax+β(41)
梯度关系如下:
权 重 更 新 : ∂ l ∂ h ′ = ∂ l ∂ h t ( w t ) T (42) 权重更新:\frac{\partial l}{\partial h'}=\frac{\partial l}{\partial h^t}(w^t)^T\tag{42} 权重更新:∂h′∂l=∂ht∂l(wt)T(42)
激 活 函 数 更 新 : ∂ l ∂ h t − 1 = a ∂ l ∂ h ′ (43) 激活函数更新:\frac{\partial l}{\partial h^{t-1}}=a\frac{\partial l}{\partial h'}\tag{43} 激活函数更新:∂ht−1∂l=a∂h′∂l(43)
E [ ∂ l ∂ h i t − 1 ] = 0 (44) E[\frac{\partial l }{\partial h_i^{t-1}}]=0\tag{44} E[∂hit−1∂l]=0(44)
V a r [ ∂ l ∂ h i t − 1 ] = a 2 V a r [ ∂ l ∂ h j ′ ] (45) Var[\frac{\partial l}{\partial h_i^{t-1}}]=a^2 Var[\frac{\partial l}{\partial h_j'}]\tag{45} Var[∂hit−1∂l]=a2Var[∂hj′∂l](45)
得出如下:
a = 1 ; β = 0 (46) a=1;\beta=0\tag{46} a=1;β=0(46)
5. 小结-激活函数
有上面可得,只有我们的激活函数满足了如下即可:
f
(
x
)
=
x
(47)
f(x)=x\tag{47}
f(x)=x(47)
- 常见的函数泰勒公式展开如下:
s i g m o i d ( x ) = 1 2 + x 4 − x 3 48 + O ( x 5 ) (48) sigmoid(x)=\frac{1}{2}+\frac{x}{4}-\frac{x^3}{48}+O(x^5)\tag{48} sigmoid(x)=21+4x−48x3+O(x5)(48)
t a n h ( x ) = 0 + x − x 3 3 + O ( x 5 ) (49) tanh(x)=0+x-\frac{x^3}{3}+O(x^5)\tag{49} tanh(x)=0+x−3x3+O(x5)(49)
R e L U ( x ) = 0 + x ; x ≥ 0 (50) ReLU(x)=0+x;\qquad x\geq 0\tag{50} ReLU(x)=0+x;x≥0(50)
我们发现,对于sigmoid激活函数时,局部也不满足 f(x)=x,为此我们可以进行调整可得: - 调整后的 sigmoid:
4 × s i g m o i d ( x ) − 2 = x − x 3 12 + O ( x 5 ) (51) 4\times sigmoid(x)-2=x-\frac{x^3}{12}+O(x^5)\tag{51} 4×sigmoid(x)−2=x−12x3+O(x5)(51)
这样我们就能在 x->0的附近,近似的认为 f(x) ≈ x.这样我们可以得到合理的初始值。
-总结
合 理 的 权 重 初 始 值 和 激 活 函 数 的 选 取 可 以 提 升 数 值 稳 定 性 合理的权重初始值和激活函数的选取可以提升数值稳定性 合理的权重初始值和激活函数的选取可以提升数值稳定性
5.1 权重初始化:
权重初始化的值来自于分布,分布期望 0, 方差 γ t \gamma_t γt
正太分布:
X
∼
N
(
0
,
2
(
n
t
−
1
+
n
t
)
)
(52)
X\sim N(0,\sqrt{\frac{2}{(n_{t-1}+n_t)}})\tag{52}
X∼N(0,(nt−1+nt)2)(52)
均匀分布:
X
∼
U
(
−
6
(
n
t
−
1
+
n
t
)
,
6
(
n
t
−
1
+
n
t
)
)
(53)
X\sim U(-\sqrt{\frac{6}{(n_{t-1}+n_t)}},\sqrt{\frac{6}{(n_{t-1}+n_t)}})\tag{53}
X∼U(−(nt−1+nt)6,(nt−1+nt)6)(53)
5.2 激活函数
激活函数的选择最好选择 ReLu(x),或者 近似于 f(x)=x的函数。
6. Xavier初始化的代码
为了实现上述初始化问题,我们可以选择如下函数:
-
均匀分布: torch.nn.init.uniform_
-
代码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: xavier_normal
# @Create time: 2021/11/28 17:49
import torch
from torch import nn
w1 = torch.empty(3,5)
w2 = torch.empty(3,5)
print(f'w1_empty={w1}')
nn.init.xavier_normal_(w1)
print(f'w_normal={w1}')
print(f'w2_empty={w2}')
nn.init.xavier_uniform_(w2)
print(f'w_uniform={w2}')
- 结果
w1_empty=tensor([[1.4802e-15, 7.3288e-43, 1.4802e-15, 7.3288e-43, 1.4798e-15],
[7.3288e-43, 1.4798e-15, 7.3288e-43, 1.4799e-15, 7.3288e-43],
[1.4799e-15, 7.3288e-43, 1.4799e-15, 7.3288e-43, 1.4799e-15]])
w_normal=tensor([[ 0.6229, 0.3186, 0.2490, 0.5447, -0.9170],
[ 0.1360, -0.1026, 0.0904, -0.5155, 0.2935],
[-0.4435, -0.3894, -0.6188, 0.3351, -0.0992]])
w2_empty=tensor([[8.9082e-39, 5.9694e-39, 8.9082e-39, 1.0194e-38, 9.1837e-39],
[4.6837e-39, 9.2755e-39, 1.0837e-38, 8.4490e-39, 1.1112e-38],
[9.5511e-39, 1.0102e-38, 9.0919e-39, 9.9184e-39, 9.0000e-39]])
w_uniform=tensor([[ 0.2040, -0.6079, 0.2713, 0.6141, -0.5691],
[-0.7909, -0.7151, 0.3155, 0.4237, 0.4385],
[-0.3826, -0.5026, -0.7302, 0.2931, -0.5977]])