主旨
TensorFlow提供了方便的API用于快速搭建和实现RNN网络。但是在实际工作中,这些API的关键参数选择令人迷惑,在没有时间详细阅读Tensorflow引用论文和源代码的条件下,仅仅靠网络上找到的样例代码决定某些参数的选择是危险且低效的。为了解决这个问题,同时不陷入过于复杂的论文和TensorFlow源代码分析,本文通过受控实验的方式,设计出一个虽然简单但是能反映出RNN基本规律的训练和测试数据,通过代码实验分析不同参数对于RNN分类精度的影响,并得出对工程有实际指导意义的结论。
运行环境和源代码
TensorFlow版本
>>> tf.version
‘1.1.0-rc2’
源代码位置:https://github.com/wangyaobupt/RNN
背景知识
RNN是递归神经网络的简称,区别于此前介绍的全连接神经网络(Full Connected Network)或者卷积神经网络(CNN),RNN的一大特点是在计算中引入了递归,即当前时刻t的输出不止由t时刻输入影响,还由t-1时刻的系统输出和系统状态影响。由于具备这样的性质,RNN在时间序列分析,特别是具备前后关联性的时间序列(例如自然语言等)非常有用。
目前LSTM是一类常用的RNN单元结构,本文不会涉及LSTM网络的原理,感兴趣的读者推荐阅读以下两篇参考资料。
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://deeplearning4j.org/lstm.html#recurrent
需要提醒注意的是:虽然LSTM名字中有“记忆”,但这里的记忆主要是为了让神经元记住此前时刻的状态,而不应该与人类的记忆能力混淆,误以为LSTM是用来记忆数据的。如果只是记忆数据,向磁盘写文件就足够好了。LSTM记忆此前时刻的状态,是为了形成一定程度上的“推理”(此处表达不够严谨):即根据对过去一段时间输入的处理结果,加上当前时刻的输入,综合分析数据特征。
LSTM典型示意图如下,示意图来自http://colah.github.io/posts/2015-08-Understanding-LSTMs/
API 和 需要确定的关键参数
TensorFlow提供了方便的API用于构造LSTM单元和网络,在本文中会用到的两个介绍如下
tf.contrib.rnn.BasicLSTMCell
根据API文档,其构造函数中num_units是没有默认值,必须由网络设计者给定。API文档中对这个参数的作用描述如下
num_units: int, The number of units in the LSTM cell
对于上述描述,笔者表示仍然看不懂,因为“units in the LSTM cell”这个概念在API文档上并没有直接定义。
为了解决这个问题,我们从TF源代码入手,分析上述API对应的源代码 core_rnn_cell_impl.py,找到如下源代码
class BasicLSTMCell(RNNCell):
"""Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
We add forget_bias (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training.
It does not allow cell clipping, a projection layer, and does not
use peep-hole connections: it is the basic baseline.
For advanced models, please use the full LSTMCell that follows.
"&#