梯度消失,梯度爆炸_原因分析_简单例子助理解
梯度消失,梯度爆炸的根源其实是来在反向传播BP(back propagation).
反向传播的思想: 每层的输出是由两层间的权重决定的,两层之间产生的误差,按权重缩放后在网络中向前传播, 这就是反向传播。
从反向传播中得到一般化公式:
δ
n
−
1
=
ω
n
−
1
δ
n
∗
f
n
−
1
′
\delta ^{n-1}=\omega _{n-1}\delta ^{n}* f_{n-1}'
δn−1=ωn−1δn∗fn−1′
Δ w n − 2 = η δ n − 1 x n − 1 \Delta w_{n-2}=\eta \delta ^{n-1}x_{n-1} Δwn−2=ηδn−1xn−1
其中 δ n \delta ^{n} δn为第 n n n层的误差项, δ n − 1 \delta ^{n-1} δn−1为第 n − 1 n-1 n−1层的误差项, ω n − 2 \omega _{n-2} ωn−2为第 n − 2 n-2 n−2层到第 n - 1 n-1 n-1层的权重, f n − 1 ′ f_{n-1}' fn−1′为第 n − 1 n-1 n−1层输出的导数,也就是激活函数的导数, x n − 1 x_{n-1} xn−1为第 n − 1 n-1 n−1层的输入, η \eta η为学习率, Δ w n − 2 \Delta w_{n-2} Δwn−2就是第 n − 2 n-2 n−2层到第 n − 1 n-1 n−1层权重更新步长了.
对于 n n n层神经网络,根据反向传播的公式,到第 n − i n-i n−i层的权重 w n − i − 1 w_{n-i-1} wn−i−1更新规则为:
δ n − i = ( ω n − i ⋅ ⋅ ⋅ ( ω n − 2 ( ω n − 1 ( ω n δ n ∗ f n − 1 ′ ) ∗ f n − 2 ′ ) ∗ f n − 3 ′ ) ⋅ ⋅ ⋅ ∗ f n − i ′ ) \delta ^{n-i}=(\omega _{n-i}\cdot\cdot\cdot(\omega _{n-2}(\omega _{n-1}(\omega _{n}\delta ^{n}* f_{n-1}')* f_{n-2}')* f_{n-3}')\cdot\cdot\cdot* f_{n-i}') δn−i=(ωn−i⋅⋅⋅(ωn−2(ωn−1(ωnδn∗fn−1′)∗fn−2′)∗fn−3′)⋅⋅⋅∗fn−i′)
Δ w n − i − 1 = η δ n − i x n − i \Delta w_{n-i-1}=\eta \delta ^{n-i}x_{n-i} Δwn−i−1=ηδn−ixn−i
上述就是权重 w n − i w_{n-i} wn−i更新规则,对于激活函数的倒数 f n − 1 ′ f_{n-1}' fn−1′, f n − 2 ′ f_{n-2}' fn−2′, f n − 3 ′ f_{n-3}' fn−3′,.., f n − i ′ f_{n-i}' fn−i′,如果此部分大于1,那么层数增多的时候,最终的求出的权重 w n − i w_{n-i} wn−i更新将以指数形式增加,即发生梯度爆炸,如果此部分小于1,那么随着层数增多,求出的权重 w n − i − 1 w_{n-i-1} wn−i−1的更新步长 Δ w n − i − 1 \Delta w_{n-i-1} Δwn−i−1将会以指数形式衰减,即发生了梯度消失。
简单例子
用下面最简单的单线神经网络来说明,更见直观的理解梯度消失,梯度爆炸.
说明:
f
f
f表示激活函数,
f
i
f_{i}
fi就表示第
i
i
i层的输出,
δ
i
\delta ^{i}
δi表示输出的误差项.
那么根据上图,可以得到第二层的误差项
δ
2
\delta ^{2}
δ2为:
δ 2 = w 2 w 3 w 4 δ 5 f 4 ′ f 3 ′ f 2 ′ \delta ^{2}=w_{2}w_{3}w_{4}\delta ^{5}f_{4}'f_{3}'f_{2}' δ2=w2w3w4δ5f4′f3′f2′
第二层的权重更新步长为:
Δ w 2 = η δ 2 x 2 \Delta w_{2}=\eta \delta ^{2}x_{2} Δw2=ηδ2x2
从上面的例子我们可以直观的看出有连乘 f 4 ′ f 3 ′ f 2 ′ f_{4}'f_{3}'f_{2}' f4′f3′f2′, 当神经网络的层数进一步增加的时候,连乘会进一步加长.所以当 f n ′ < 1 f_{n}'<1 fn′<1的时候,随着累乘的增加(远离输出端),误差项 δ \delta δ会逐渐趋近0,这就是梯度消失.当 f n ′ > 1 f_{n}'>1 fn′>1的时候,随着累乘的增加(远离输出端),误差项 δ \delta δ会逐渐趋近无穷,这就是梯度爆炸.