本篇文章参考于 RNN梯度消失和爆炸的原因、Towser关于LSTM如何来避免梯度弥散和梯度爆炸?的问题解答、Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass。
看本篇文章之前,建议自行学习RNN和LSTM的前向和反向传播过程,学习教程可参考刘建平老师博客循环神经网络(RNN)模型与前向反向传播算法、LSTM模型与前向反向传播算法。
具体了解LSTM如何解决RNN所带来的梯度消失问题之前,我们需要明白为什么RNN会带来梯度消失问题。
1. RNN梯度消失原因
如上图所示,为RNN模型结构,前向传播过程包括,
- 隐藏状态: h ( t ) = σ ( z ( t ) ) = σ ( U x ( t ) + W h ( t − 1 ) + b ) h^{(t)} = \sigma (z^{(t)}) = \sigma(Ux^{(t)} + Wh^{(t-1)} + b) h(t)=σ(z(t))=σ(Ux(t)+Wh(t−1)+b),此处激活函数一般为 t a n h tanh tanh。
- 模型输出: o ( t ) = V h ( t ) + c o^{(t)} = Vh^{(t)} + c o(t)=Vh(t)+c
- 预测输出: y ^ ( t ) = σ ( o ( t ) ) \hat{y}^{(t)} = \sigma(o^{(t)}) y^(t)=σ(o(t)),此处激活函数一般为softmax。
- 模型损失: L = ∑ t = 1 T L ( t ) L = \sum_{t = 1}^{T} L^{(t)} L=∑t=1TL(t)
RNN反向传播过程中,需要计算 U , V , W U, V, W U,V,W等参数的梯度,以 W W W的梯度表达式为例,
∂ L ∂ W = ∑ t = 1 T ∂ L ∂ y ( T ) ∂ y ( T ) ∂ o ( T ) ∂ o ( T ) ∂ h ( T ) ∂ h ( T ) ∂ h ( t ) ∂ h ( t ) ∂ W \frac{\partial L}{\partial W} = \sum_{t = 1}^{T} \frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}} \frac{\partial h^{(T)}}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial W} ∂W∂L=t=1∑T∂y(T)∂L∂o(T)∂y(T)∂h(T)∂o(T)∂h(t)∂h(T)∂W∂h(t)
现在需要重点计算 ∂ h ( T ) ∂ h ( t ) \frac{\partial h^{(T)}}{\partial h^{(t)}} ∂h(t)∂h(T)部分,展开得到,
∂ h ( T ) ∂ h ( t ) = ∂ h ( T ) ∂ h ( T − 1 ) ∂ h ( T − 1 ) ∂ h ( T − 2 ) . . . ∂ h ( t + 1 ) ∂ h ( t ) = ∏ k = t + 1 T ∂ h ( k ) ∂ h ( k − 1 ) = ∏ k = t + 1 T t a n h ′ ( z ( k ) ) W \frac{\partial h^{(T)}}{\partial h^{(t)}} = \frac{\partial h^{(T)}}{\partial h^{(T-1)}} \frac{\partial h^{(T - 1)}}{\partial h^{(T-2)}} ...\frac{\partial h^{(t+1)}}{\partial h^{(t)}} = \prod_{k=t + 1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k - 1)}} = \prod_{k=t+1}^{T} tanh^{'}(z^{(k)}) W ∂h(t)∂h(T)=∂h(T−1)∂h(T)