TensorFlow中RNN网络的实现和关键参数选择

主旨

TensorFlow提供了方便的API用于快速搭建和实现RNN网络。但是在实际工作中,这些API的关键参数选择令人迷惑,在没有时间详细阅读Tensorflow引用论文和源代码的条件下,仅仅靠网络上找到的样例代码决定某些参数的选择是危险且低效的。为了解决这个问题,同时不陷入过于复杂的论文和TensorFlow源代码分析,本文通过受控实验的方式,设计出一个虽然简单但是能反映出RNN基本规律的训练和测试数据,通过代码实验分析不同参数对于RNN分类精度的影响,并得出对工程有实际指导意义的结论。

运行环境和源代码

TensorFlow版本

>>> tf.version
‘1.1.0-rc2’

源代码位置:https://github.com/wangyaobupt/RNN

背景知识

RNN是递归神经网络的简称,区别于此前介绍的全连接神经网络(Full Connected Network)或者卷积神经网络(CNN),RNN的一大特点是在计算中引入了递归,即当前时刻t的输出不止由t时刻输入影响,还由t-1时刻的系统输出和系统状态影响。由于具备这样的性质,RNN在时间序列分析,特别是具备前后关联性的时间序列(例如自然语言等)非常有用。

目前LSTM是一类常用的RNN单元结构,本文不会涉及LSTM网络的原理,感兴趣的读者推荐阅读以下两篇参考资料。
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://deeplearning4j.org/lstm.html#recurrent

需要提醒注意的是:虽然LSTM名字中有“记忆”,但这里的记忆主要是为了让神经元记住此前时刻的状态,而不应该与人类的记忆能力混淆,误以为LSTM是用来记忆数据的。如果只是记忆数据,向磁盘写文件就足够好了。LSTM记忆此前时刻的状态,是为了形成一定程度上的“推理”(此处表达不够严谨):即根据对过去一段时间输入的处理结果,加上当前时刻的输入,综合分析数据特征。

LSTM典型示意图如下,示意图来自http://colah.github.io/posts/2015-08-Understanding-LSTMs/
这里写图片描述

API 和 需要确定的关键参数

TensorFlow提供了方便的API用于构造LSTM单元和网络,在本文中会用到的两个介绍如下

tf.contrib.rnn.BasicLSTMCell

根据API文档,其构造函数中num_units是没有默认值,必须由网络设计者给定。API文档中对这个参数的作用描述如下

num_units: int, The number of units in the LSTM cell

对于上述描述,笔者表示仍然看不懂,因为“units in the LSTM cell”这个概念在API文档上并没有直接定义。

为了解决这个问题,我们从TF源代码入手,分析上述API对应的源代码 core_rnn_cell_impl.py,找到如下源代码

class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.
  The implementation is based on: http://arxiv.org/abs/1409.2329.
  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.
  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.
  For advanced models, please use the full LSTMCell that follows.
  "&#
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值