大家好,今天和各位分享一下长短时记忆网络 LSTM 的原理,并使用 Pytorch 从公式上实现 LSTM 层
上一节介绍了循环神经网络 RNN,感兴趣的可以看一下:https://blog.csdn.net/dgvv4/article/details/125424902
我的这个专栏中有许多 LSTM 的实战案例,便于大家巩固知识:https://blog.csdn.net/dgvv4/category_11712004.html
1. 引言
循环神经网络的记忆功能在处理时间序列问题上存在很大优势,但随着训练的不断进行,RNN 网络一直在不断的扩充记忆,致使 RNN 产生梯度消失以及梯度爆炸。
为了解决RNN难以有效训练的问题,拥有选择记忆功能的 LSTM模型被提出。LSTM 是在 RNN 的基础上进行的改进,其既能学习数据中的长期依赖,又能解决梯度消失。LSTM 包含一个记忆单元和三个门结构,其中门结构分别是输入门、输出门和遗忘门。
LSTM 的工作过程如下:
首先由输入数据 X_t 与前一时刻隐藏层的输出数据 h_t-1 共同作用于遗忘门,遗忘门对上述信息进行筛选,记忆时间序列中的重要特征信息,丢弃无关紧要的信息;然后将输入数据 x_t 以及前一时刻隐藏层的输出数据 h_t-1 作为输入门的输入信息,进行更新;其次记忆单元通过输入数据 X_t、前一时刻隐藏层的输出数据 h_t-1 以及前一时刻的记忆单元状态 C_t-1 对自身状态进行更新;最后将输入数据 X_t、前一时刻隐藏层的输出数据 h_t-1 以及当前时刻的记忆单元状态 C_t 共同作用于输出门,输出当前时刻的隐藏层信息 h_t。
LSTM 的结构图如下: