对于RNN和LSTM不了解的朋友,可以去看看这两篇入门介绍,写的非常棒,在此特别感谢两位作者!!
RNN入门:https://zhuanlan.zhihu.com/p/28054589
LSTM入门:http://colah.github.io/posts/2015-08-Understanding-LSTMs
本文参照了:https://jasdeep06.github.io/posts/Understanding-LSTM-in-Tensorflow-MNIST/ 对这篇文章结合自己的理解作出翻译。
1.MNIST数据集结构
- Training data(
mnist.train
)-55000 个训练数据集 - Test data(
mnist.test
)-10000 个测试数据集 - Validation data(
mnist.validation
)-5000个验证数据集
每个类别又分为了images和labels,也就是图片以及标签,每张图片都是(28*28*1)的,数据集中将图片的特征值压缩为(number,784)。LSTMs通常适用于复杂的序列问题像自然语言处理这类的问题,但这种问题本身就不太好理解,我们的主要目标是去理解LSTMs在tensorflow中具体实现细节,比如处理输入格式,LSTM的cell运作以及对网络模型的整体设计。MNIST就是一个不错的选择。
2.Implementation
首先给出一张RNN网络图,理解了这张图再去实现代码就会更直观。
- xt 代表了每个时间节点的输入
- st 代表了在t时间点的隐藏单元 这也成为网络的记忆 memory
- ot 代表每个时间点的输出
- U,V and W 是每个时间点都共享的参数