读书笔记-数值稳定性和模型初始化

1. 背景

我们在训练神经网络的过程中,需要经常对神经网络进行随机初始化。但当神经网络不复杂的时候,我们可以不用太关心参数初始化。然而对于深度神经网络来说,初始化方案对于训练的收敛性起到至关作用。糟糕的初始化会让发生梯度爆炸和梯度消失。

2. 梯度消失&梯度爆炸

假设我们有 d 层的深度神经网络, t 表示层数。
h t = f t ( h t − 1 ) (1) h^t=f_t(h^{t-1})\tag{1} ht=ft(ht1)(1)
y = l ⋅ f d ⋅ . . . ⋅ f 1 ( x ) (2) y=l·f_d·...·f_1(x)\tag{2} y=lfd...f1(x)(2)
那么我们可以计算损失值 l 关于权重参数 w 的导数如下:
∂ l ∂ w t = ∂ l ∂ h d ⋅ ∂ h d ∂ h d − 1 . . . ∂ h t + 1 ∂ h t ⏟ d − t 次 矩 阵 乘 法 ⋅ ∂ h t ∂ w t (3) \frac{\partial l}{\partial w^t}=\frac{\partial l}{\partial h^d}·\underbrace{\frac{\partial h^d}{\partial h^{d-1}}...\frac{\partial h^{t+1}}{\partial h^{t}}}_{d-t次矩阵乘法}·\frac{\partial h^t}{\partial w^t}\tag{3} wtl=hdldt hd1hd...htht+1wtht(3)

因为要进行 d-t 次矩阵乘法,那么如果每一个梯度值为 m 。
1. 5 100 ≈ 4 × 1 0 17 → 梯 度 爆 炸 (4) 1.5^{100}≈4×10^{17}\rightarrow 梯度爆炸\tag{4} 1.51004×1017(4)
0. 8 100 ≈ 2 × 1 0 − 10 → 梯 度 消 失 (5) 0.8^{100}≈2×10^{-10}\rightarrow 梯度消失\tag{5} 0.81002×1010(5)

  • 梯度爆炸:导致了梯度超过计算机值的范围,造成上溢
  • 梯度消失:导致了梯度十分的小,导致神经网络无法更新相关参数
    为了解决上述梯度爆炸和梯度消失问题,我们希望在每层梯度都在合理范围内。
  • MLP多层感知机举例
    假设第 t 层的函数如下:
    f t ( h t − 1 ) = σ ( w t h t − 1 ) ; σ 是 激 活 函 数 (6) f_t(h^{t-1})=\sigma(w^th^{t-1});\sigma是激活函数\tag{6} ft(ht1)=σ(wtht1);σ(6)
    对权重 w t w_t wt求导可得:
    ∂ h t ∂ h t − 1 = d i a g ( σ ′ ( w t h t − 1 ) ) ( w t ) T ; σ ′ 是 σ 的 导 数 函 数 (7) \frac{\partial h^t}{\partial h^{t-1}}=diag(\sigma'(w^th^{t-1}))(w^t)^T;\sigma'是\sigma的导数函数\tag{7} ht1ht=diag(σ(wtht1))(wt)T;σσ(7)
    ∏ i = t d − 1 ∂ h t − 1 ∂ h i = ∏ i = t d − 1 d i a g ( σ ′ ( w t h t − 1 ) ) ( w t ) T (8) \prod_{i=t}^{d-1}\frac{\partial h^{t-1}}{\partial h^i}=\prod_{i=t}^{d-1}diag(\sigma'(w^th^{t-1}))(w^t)^T\tag{8} i=td1hiht1=i=td1diag(σ(wtht1))(wt)T(8)
  • 假设使用ReLU作为激活函数,那么可以得到:
    σ ( x ) = m a x ( 0 , x ) (9) \sigma(x)=max(0,x)\tag{9} σ(x)=max(0,x)(9)

σ ′ ( x ) = { 1 , i f x > 0 0 ; o t h e r w i s e (10) \sigma'(x)=\left\{\begin{array}{l} 1 ,\qquad if \quad x>0 \\ 0 ;\qquad otherwise \end{array}\right.\tag{10} σ(x)={1,ifx>00;otherwise(10)

  • 注:当我们的值大于零时候, ∏ i = t d − 1 ∂ h t − 1 ∂ h i \prod_{i=t}^{d-1}\frac{\partial h^{t-1}}{\partial h^i} i=td1hiht1中的一些元素的值就由 ∏ i = t d − 1 ( w i ) T \prod_{i=t}^{d-1}(w^i)^T i=td1(wi)T来决定,如果 d-t很大,那么这个连乘值就非常的大,从而导致梯度爆炸或者梯度消失。

3. 模型初始化

为了解决梯度爆炸和梯度消失,我们希望对找到一个合理的区间进行初始化,常见思路如下:

  • 目标:让梯度值在合理的范围内,例如[le-6,le3]
  • 将乘法变加法: ResNet,LSTM
  • 归一化:梯度归一化,梯度裁剪
  • 合理的权重初始和激活函数 <重点>

3.1 目标

为了解决上述问题,我们希望是:让每层的方差是一个常数

  • 将每层的输出和梯度都看作是随机变量
  • 让它们的均值和方差都保持一致
  • 正向期望和方差:
    E [ h i t ] = 0 ; V a r [ h i t ] = a (11) E[h^t_i]=0;\quad Var[h_i^t]=a\tag{11} E[hit]=0;Var[hit]=a(11)
  • 反向期望和方差:
    E [ ∂ l ∂ h i t ] = 0 ; V a r [ ∂ l ∂ h i t ] = b ; ∀ i , t (12) E[\frac{\partial l}{\partial h^t_i}]=0;\quad Var[\frac{\partial l}{\partial h^t_i}]=b;\quad \forall i,t\tag{12} E[hitl]=0;Var[hitl]=b;i,t(12)
  • 注:其中 a,b 都是常数
  • 权重初始化
    在合理值区间里随机初始参数
    训练开始的时候更容易有数值不稳定。常见思路如下: a .远离最优解的地方损失函数表面可能很复杂;b.最优解附近表面会比较平;c.使用N(0,0.01)来初始可能对小玩过没问题,但不能保证深度神经网络
  • 我们以 MLP 为例来讲解下:

假设<1>
权重服从独立同分布,即 w i , j t w_{i,j}^t wi,jt是 i.i.d,且假设均值为 0 ,方差为 γ t \gamma_t γt,且 h i t − 1 h_i^{t-1} hit1独立于 w i , j t w_{i,j}^t wi,jt
E [ w i , j t ] = 0 ; V a r [ w i , j t ] = γ t (13) E[w_{i,j}^t]=0;\quad Var[w_{i,j}^t]=\gamma_t\tag{13} E[wi,jt]=0;Var[wi,jt]=γt(13)
假设<2>
假设这个 MLP 没有激活函数,满足 h t = w t h t − 1 h^t=w^th^{t-1} ht=wtht1,这里 w t ∈ R n t × n t − 1 w^t \in R^{n_t×n_{t-1}} wtRnt×nt1
E [ h i t ] = E [ ∑ j w i , j t h j t − 1 ] = ∑ j E [ w i , j t ] E [ h j t − 1 ] = 0 (14) E[h_i^t]=E[\sum_jw^t_{i,j}h_j^{t-1}]=\sum_jE[w^t_{i,j}]E[h_j^{t-1}]=0\tag{14} E[hit]=E[jwi,jthjt1]=jE[wi,jt]E[hjt1]=0(14)

  • 正向方差计算:
    V a r [ h i t ] = E [ ( h i t ) 2 ] − E [ h i t ] 2 = E [ ( h i t ) 2 ] − 0 = E [ ( ∑ j w i , j t h j t − 1 ) 2 ] (15) Var[h_i^t]=E[(h_i^t)^2]-E[h_i^t]^2=E[(h_i^t)^2]-0=E[(\sum_jw^t_{i,j}h_j^{t-1})^2]\tag{15} Var[hit]=E[(hit)2]E[hit]2=E[(hit)2]0=E[(jwi,jthjt1)2](15)
    展开上式可得:
    = E [ ∑ j ( w i , j t ) 2 ( h j t − 1 ) 2 + ∑ j ≠ k w i , j t w i , k t h j t − 1 h k t − 1 ] (16) =E[\sum_j(w_{i,j}^t)^2(h_j^{t-1})^2+\sum_{j≠k}w_{i,j}^tw_{i,k}^th_j^{t-1}h_k^{t-1}]\tag{16} =E[j(wi,jt)2(hjt1)2+j=kwi,jtwi,kthjt1hkt1](16)
    因为 E [ h i t ] = 0 E[h_i^t]=0 E[hit]=0
    = E [ ∑ j ( w i , j t ) 2 ( h j t − 1 ) 2 ] = ∑ j E [ ( w i , j t ) 2 ] E [ ( h j t − 1 ) 2 ] (17) =E[\sum_j(w_{i,j}^t)^2(h_j^{t-1})^2]=\sum_jE[(w_{i,j}^t)^2]E[(h_j^{t-1})^2]\tag{17} =E[j(wi,jt)2(hjt1)2]=jE[(wi,jt)2]E[(hjt1)2](17)
    因为 V a r [ h i t ] = E [ ( h i t ) 2 ] − E [ h i t ] 2 = E [ ( h i t ) 2 ] − 0 = E [ ( h i t ) 2 ] Var[h_i^t]=E[(h_i^t)^2]-E[h_i^t]^2=E[(h_i^t)^2]-0=E[(h_i^t)^2] Var[hit]=E[(hit)2]E[hit]2=E[(hit)2]0=E[(hit)2]
    所以 E [ ( w i , j t ) 2 ] = V a r [ w i , j t ] , E [ ( h j t − 1 ) 2 ] = V a r [ h j t − 1 ] E[(w_{i,j}^t)^2]=Var[w_{i,j}^t],E[(h_j^{t-1})^2]=Var[h_j^{t-1}] E[(wi,jt)2]=Var[wi,jt],E[(hjt1)2]=Var[hjt1]
    V a r [ h i t ] = E [ ∑ j ( w i , j t ) 2 ( h j t − 1 ) 2 ] = ∑ j V a r [ w i , j t ] V a r [ h j t − 1 ] (18) Var[h_i^t]=E[\sum_j(w_{i,j}^t)^2(h_j^{t-1})^2]=\sum_jVar[w_{i,j}^t]Var[h_j^{t-1}]\tag{18} Var[hit]=E[j(wi,jt)2(hjt1)2]=jVar[wi,jt]Var[hjt1](18)
    因为 V a r [ w i , j t ] = n t − 1 γ t Var[w_{i,j}^t]=n_{t-1}\gamma_t Var[wi,jt]=nt1γt
    V a r [ h i t ] = ∑ j V a r [ w i , j t ] V a r [ h j t − 1 ] = n t − 1 γ t V a r [ h j t − 1 ] (19) Var[h_i^t]=\sum_jVar[w_{i,j}^t]Var[h_j^{t-1}]=n_{t-1}\gamma_tVar[h_j^{t-1}]\tag{19} Var[hit]=jVar[wi,jt]Var[hjt1]=nt1γtVar[hjt1](19)
    这样我们就得到了递推公式:
    V a r [ h i t ] = n t − 1 γ t V a r [ h j t − 1 ] (20) Var[h_i^t]=n_{t-1}\gamma_tVar[h_j^{t-1}]\tag{20} Var[hit]=nt1γtVar[hjt1](20)
    那么为了保证我们的数据在训练过程中满足方差不变,那么我们只需要满足如下:
    n t − 1 γ t = 1 (21) n_{t-1}\gamma_t=1\tag{21} nt1γt=1(21)
  • 反向方差计算:
    ∂ l ∂ h t − 1 = ∂ l ∂ h t w t (22) \frac{\partial l}{\partial h^{t-1}}=\frac{\partial l}{\partial h^{t}}w^t\tag{22} ht1l=htlwt(22)
    ( ∂ l ∂ h t − 1 ) T = [ w t ] T ( ∂ l ∂ h t ) T (23) (\frac{\partial l}{\partial h^{t-1}})^T=[w^t]^T(\frac{\partial l}{\partial h^{t}})^T\tag{23} (ht1l)T=[wt]T(htl)T(23)
    因为我们假设了期望为 0,方差为常数 γ t \gamma_t γt
    E [ ∂ l ∂ h i t − 1 ] = 0 (24) E[\frac{\partial l}{\partial h^{t-1}_i}]=0\tag{24} E[hit1l]=0(24)
    V a r [ ∂ l ∂ h i t − 1 ] = n t γ t V a r [ ∂ l ∂ h j t ] (25) Var[\frac{\partial l}{\partial h^{t-1}_i}]=n_t\gamma_tVar[\frac{\partial l}{\partial h^{t}_j}]\tag{25} Var[hit1l]=ntγtVar[hjtl](25)
    以上为迭代公式,为了保证整体的方差不变,需要满足如下:
    n t γ t = 1 (26) n_t\gamma_t=1\tag{26} ntγt=1(26)

3.2 分析

我们已经通过正向和反向运算可以得出,需要满足两个条件
n t − 1 γ t = 1 (27) n_{t-1}\gamma_t=1\tag{27} nt1γt=1(27)
n t γ t = 1 (28) n_t\gamma_t=1\tag{28} ntγt=1(28)

  • 注: n t − 1 n_{t-1} nt1是第 t 层输入的维度; n t n_t nt是第 t 层输出的维度, γ t \gamma_t γt表示第 t 层权重的方差,除非输入与输出相同,否侧无法满足上述条件,为了解决上述问题,我们引入了Xavier。

3.3 Xavier 初始化

将 公式 <27>,<28>相加后可得
( n t − 1 + n t ) γ t = 2 (29) (n_{t-1}+n_t)\gamma_t=2\tag{29} (nt1+nt)γt=2(29)
γ t = 2 n t − 1 + n t (30) \gamma_t=\frac{2}{n_{t-1}+n_t}\tag{30} γt=nt1+nt2(30)
也就是说在给定输入输出维度时,我们希望权重满足期望为 0, 方差为 γ t = 2 n t − 1 + n t \gamma_t=\frac{2}{n_{t-1}+n_t} γt=nt1+nt2

  • 对第 t 层的权重层进行初始化权重时,初始化的常见分布如下:
    正太分布:
    X ∼ N ( 0 , 2 ( n t − 1 + n t ) ) (31) X\sim N(0,\sqrt{\frac{2}{(n_{t-1}+n_t)}})\tag{31} XN(0,(nt1+nt)2 )(31)
    均匀分布:
    X ∼ U ( − 6 ( n t − 1 + n t ) , 6 ( n t − 1 + n t ) ) (32) X\sim U(-\sqrt{\frac{6}{(n_{t-1}+n_t)}},\sqrt{\frac{6}{(n_{t-1}+n_t)}})\tag{32} XU((nt1+nt)6 ,(nt1+nt)6 )(32)
  • 适配权重形状变换,特别是 n t n_t nt

4. 激活函数

假设这个 MLP 有激活函数时,分布满足期望为 0 , 方差为常数 γ t \gamma_t γt ,并且此激活函数为线性的。我们知道一般是不会选择线性激活函数。我们这里暂且这样假设,通过后续分析此合理性。

  • 正向期望方差计算:
    σ ( x ) = a x + β (33) \sigma(x)=ax+\beta\tag{33} σ(x)=ax+β(33)
    h ′ = w t h t − 1 (34) h'=w^th^{t-1}\tag{34} h=wtht1(34)
    h t = σ ( h ′ ) (35) h^t=\sigma(h')\tag{35} ht=σ(h)(35)
    那么我们来计算期望和方差:
    E [ h i t ] = E [ a h t ′ + β ] = a E [ h t ′ ] + β = β (36) E[h_i^t]=E[ah_t'+\beta]=aE[h_t']+\beta=\beta\tag{36} E[hit]=E[aht+β]=aE[ht]+β=β(36)
    V a r [ h i t ] = E [ ( h i t ) 2 ] − E 2 [ h i t ] = E [ ( a h t ′ + β ) 2 ] − β 2 (37) Var[h_i^t]=E[(h_i^t)^2]-E^2[h_i^t]=E[(ah_t'+\beta)^2]-{\beta}^2\tag{37} Var[hit]=E[(hit)2]E2[hit]=E[(aht+β)2]β2(37)
    V a r [ h i t ] = E [ a 2 ( h i ′ ) 2 + 2 a β h i ′ + β 2 ) ] − β 2 = a 2 V a r [ h i ′ ] (38) Var[h_i^t]=E[a^2(h_i')^2+2a\beta h_i'+\beta^2)]-{\beta}^2=a^2Var[h_i']\tag{38} Var[hit]=E[a2(hi)2+2aβhi+β2)]β2=a2Var[hi](38)
    为了保证在经过第 t 层的后还是满足前后的两个分布是期望为零,方差不变,通过迭代公式<36>,<38>那么需要使得
    β = 0 (39) \beta=0\tag{39} β=0(39)
    a = 1 (40) a=1\tag{40} a=1(40)
    所以线性激活函数应该为:
    σ ( x ) = x (40) \sigma(x)=x\tag{40} σ(x)=x(40)
  • 反向期望方差计算:
    线性激活函数:
    σ ( x ) = a x + β (41) \sigma(x)=ax+\beta\tag{41} σ(x)=ax+β(41)
    梯度关系如下:
    权 重 更 新 : ∂ l ∂ h ′ = ∂ l ∂ h t ( w t ) T (42) 权重更新:\frac{\partial l}{\partial h'}=\frac{\partial l}{\partial h^t}(w^t)^T\tag{42} hl=htl(wt)T(42)
    激 活 函 数 更 新 : ∂ l ∂ h t − 1 = a ∂ l ∂ h ′ (43) 激活函数更新:\frac{\partial l}{\partial h^{t-1}}=a\frac{\partial l}{\partial h'}\tag{43} ht1l=ahl(43)
    E [ ∂ l ∂ h i t − 1 ] = 0 (44) E[\frac{\partial l }{\partial h_i^{t-1}}]=0\tag{44} E[hit1l]=0(44)
    V a r [ ∂ l ∂ h i t − 1 ] = a 2 V a r [ ∂ l ∂ h j ′ ] (45) Var[\frac{\partial l}{\partial h_i^{t-1}}]=a^2 Var[\frac{\partial l}{\partial h_j'}]\tag{45} Var[hit1l]=a2Var[hjl](45)
    得出如下:
    a = 1 ; β = 0 (46) a=1;\beta=0\tag{46} a=1;β=0(46)

5. 小结-激活函数

有上面可得,只有我们的激活函数满足了如下即可:
f ( x ) = x (47) f(x)=x\tag{47} f(x)=x(47)

  • 常见的函数泰勒公式展开如下:
    s i g m o i d ( x ) = 1 2 + x 4 − x 3 48 + O ( x 5 ) (48) sigmoid(x)=\frac{1}{2}+\frac{x}{4}-\frac{x^3}{48}+O(x^5)\tag{48} sigmoid(x)=21+4x48x3+O(x5)(48)
    t a n h ( x ) = 0 + x − x 3 3 + O ( x 5 ) (49) tanh(x)=0+x-\frac{x^3}{3}+O(x^5)\tag{49} tanh(x)=0+x3x3+O(x5)(49)
    R e L U ( x ) = 0 + x ; x ≥ 0 (50) ReLU(x)=0+x;\qquad x\geq 0\tag{50} ReLU(x)=0+x;x0(50)
    我们发现,对于sigmoid激活函数时,局部也不满足 f(x)=x,为此我们可以进行调整可得:
  • 调整后的 sigmoid:
    4 × s i g m o i d ( x ) − 2 = x − x 3 12 + O ( x 5 ) (51) 4\times sigmoid(x)-2=x-\frac{x^3}{12}+O(x^5)\tag{51} 4×sigmoid(x)2=x12x3+O(x5)(51)
    这样我们就能在 x->0的附近,近似的认为 f(x) ≈ x.这样我们可以得到合理的初始值。
    -总结
    合 理 的 权 重 初 始 值 和 激 活 函 数 的 选 取 可 以 提 升 数 值 稳 定 性 合理的权重初始值和激活函数的选取可以提升数值稳定性

5.1 权重初始化:

权重初始化的值来自于分布,分布期望 0, 方差 γ t \gamma_t γt

正太分布:
X ∼ N ( 0 , 2 ( n t − 1 + n t ) ) (52) X\sim N(0,\sqrt{\frac{2}{(n_{t-1}+n_t)}})\tag{52} XN(0,(nt1+nt)2 )(52)
均匀分布:
X ∼ U ( − 6 ( n t − 1 + n t ) , 6 ( n t − 1 + n t ) ) (53) X\sim U(-\sqrt{\frac{6}{(n_{t-1}+n_t)}},\sqrt{\frac{6}{(n_{t-1}+n_t)}})\tag{53} XU((nt1+nt)6 ,(nt1+nt)6 )(53)

5.2 激活函数

激活函数的选择最好选择 ReLu(x),或者 近似于 f(x)=x的函数。

6. Xavier初始化的代码

为了实现上述初始化问题,我们可以选择如下函数:

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: xavier_normal
# @Create time: 2021/11/28 17:49
import torch
from torch import nn

w1 = torch.empty(3,5)
w2 = torch.empty(3,5)
print(f'w1_empty={w1}')
nn.init.xavier_normal_(w1)
print(f'w_normal={w1}')
print(f'w2_empty={w2}')
nn.init.xavier_uniform_(w2)
print(f'w_uniform={w2}')
  • 结果
w1_empty=tensor([[1.4802e-15, 7.3288e-43, 1.4802e-15, 7.3288e-43, 1.4798e-15],
        [7.3288e-43, 1.4798e-15, 7.3288e-43, 1.4799e-15, 7.3288e-43],
        [1.4799e-15, 7.3288e-43, 1.4799e-15, 7.3288e-43, 1.4799e-15]])
w_normal=tensor([[ 0.6229,  0.3186,  0.2490,  0.5447, -0.9170],
        [ 0.1360, -0.1026,  0.0904, -0.5155,  0.2935],
        [-0.4435, -0.3894, -0.6188,  0.3351, -0.0992]])
w2_empty=tensor([[8.9082e-39, 5.9694e-39, 8.9082e-39, 1.0194e-38, 9.1837e-39],
        [4.6837e-39, 9.2755e-39, 1.0837e-38, 8.4490e-39, 1.1112e-38],
        [9.5511e-39, 1.0102e-38, 9.0919e-39, 9.9184e-39, 9.0000e-39]])
w_uniform=tensor([[ 0.2040, -0.6079,  0.2713,  0.6141, -0.5691],
        [-0.7909, -0.7151,  0.3155,  0.4237,  0.4385],
        [-0.3826, -0.5026, -0.7302,  0.2931, -0.5977]])
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值