习题6-4 推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果
LSTM(长短期记忆网络)是一种特殊的循环神经网络(RNN),旨在解决普通 RNN 在处理长序列时遇到的梯度消失和梯度爆炸问题。它通过设计多个门控机制来实现更好地学习和记忆序列中的长期依赖关系。
按照上图LSTM循环单元的结构来进行前向传播的过程
对于每个时间步t,LSTM的输入包括:
当前时间步的输入
上一时刻的隐藏状态
上一时刻的记忆单元状态
(一)LSTM的前向传播
1.LSTM的遗忘门(forget gate)
决定了上一时刻的记忆单元状态有多少比例要“遗忘”,如果遗忘门算出来的结果是0.8,是上一时刻的记忆乘以0.8,有80%的要记住,而不是80%要遗忘。
遗忘门的值:
2.LSTM的输入门(input gate)
决定了当前时刻的输入有多少比例要“更新”记忆单元
输入门的值:
3.LSTM的候选记忆单元(cell state )
生成当前时刻的新候选记忆单元
新的候选记忆单元:
4.更新记忆单元
记忆单元状态是通过遗忘门、输入门和候选记忆单元来更新的:
5.LSTM的输出门(output gate)
决定了记忆单元有多少信息可以影响到输出
输出门:
6.计算隐藏状态
隐藏状态 是通过输出门和当前时刻的记忆单元
来计算的:
7.计算输出
8.总结前向传播过程
Ok,我们已经完成了前向传播的过程,计算顺序是计算遗忘门、输入门、候选记忆单元,然后根据前一个时间步的记忆单元、遗忘门、候选记忆单元、输入门更新记忆单元。计算输出门,计算隐层输出,计算预测输出。
(二) LSTM的反向梯度推导
LSTM 的反向传播主要依赖链式法则,并且要计算每个门控的梯度。由于 LSTM 结构复杂,反向传播过程的推导也会比普通 RNN 更加复杂。
对于每个时间步 t,我们需要通过链式法则计算损失函数对 LSTM 各个参数的梯度
首先定义两种隐藏状态的梯度:
为了方便推导,给出数据在LSTM中的前向流动:
下面是自己画的: