长短期记忆网络LSTM
上一回我们学习了RNN循环神经网络,而LSTM网络是对RNN网络做的一个改良。
由于RNN模型中存在梯度爆炸与梯度消失,尤其是梯度消失的问题非常严重,针对这个问题,LSTM做了改良。
RNN是想把所有的信息都记住,不论是有没有用的,这样就会导致遗忘掉许多东西,而LSTM设计了一个记忆细胞,具备选择性记忆的功能,可以选择记忆重要信息,过滤掉噪声信息,减轻记忆负担。
一.前向传播
1.LSTM
LSTM:
(1)单元结构
Ht-1可以理解为上一个事假的权重得分。
(2)原理
遗忘门——forget gate:
通过乘法运算,遗忘掉矩阵中0,保留1(清空不相关的学习到的知识)。
更新门——update gate:
过滤掉无关的知识
Ct:
生成新知识
输出门——output gate:
Ot:选取一部分知识解题
tanh:将学习知识转化为解题能力
2.RNN的梯度消失
RNN结构
Wx梯度反向传播公式推导:
当t越多,Ws连乘越来越趋近于无穷大(Wx很大时)或者是0(Wx很小时)。
3.LSTM的反向传播
通俗的讲就是在RNN的Wx推导出的公式,把连乘变成了加法,(RNN的梯度下降只有一条路径,而LSTM有多条路径)降低了梯度消失的影响。
也可以这么理解:∂C(t)/∂C(t-1)是多个W线性相加的综合结果,其中某个W很大或很小,也没关系,可以由其他W进行协调,当模型觉得有必要进行记忆的时候,就会尽可能使得∂C(t)/∂C(t-1)=1;
RNN首先它的∂S(t)/∂S(t-1)只包含一个W,并且是Wⁿ,当W很大或很小时,就会导致梯度爆炸或梯度消失
那也就是说LSTM这个记忆细胞的偏导更加有泛化能力,是综合了多种W,当某些W不正常时,其他W还是正常的。这让我想起了,我之前看的一篇博客,梯度正常+梯度消失=梯度正常。因为记忆细胞来自多条路径,因此它不受少部分不正常的W影响。
4.代码
关于LSTM输入,输出,参数等大家可以自行搜索。
=>博客指路:https://blog.csdn.net/weixin_41744