TensorFlow系列专题(八):七步带你实现RNN循环神经网络小示例

在前面的内容里,我们已经学习了循环神经网络的基本结构和运算过程,这一小节里,我们将用TensorFlow实现简单的RNN,并且用来解决时序数据的预测问题,看一看RNN究竟能达到什么样的效果,具体又是如何实现的。


在这个演示项目里,我们使用随机生成的方式生成一个数据集(由0和1组成的二进制序列),然后人为的增加一些数据间的关系。最后我们把这个数据集放进RNN里,让RNN去学习其中的关系,实现二进制序列的预测1。数据生成的方式如下:

循环生成规模为五十万的数据集,每次产生的数据为0或1的概率均为0.5。如果连续生成了两个1(或两个0)的话,则下一个数据强制为0(或1)。

1. 我们首先导入需要的Python模块:

31bf15f8e54bd0640b5d33d0a6e57a28c77ee95d

2. 定义一个Data类,用来产生数据:

340afe74ac32d0127f3a9e77e8ef5e73278d471a

3. 在构造方法“__init__”中,我们初始化了数据集的大小“data_size”、一个batch的大小“batch_size”、一个epoch中的batch数目“num_batch”以及RNN的时间步“time_step”。接下来我们定义一个“generate_data”方法:

0da65e54c57126c04846f171ea6f4c712338af62

在第11行代码中,我们用了“np.random.choice”函数生成的由0和1组成的长串数据。接下来我们用了一个for循环,在“data_without_rel”保存的数据的基础上重新生成了一组数据,并保存在“data_with_rel”数组中。为了使生成的数据间具有一定的序列关系,我们使用了前面介绍的很简单的数据生成方式:以“data_without_rel”中的数据为参照,如果出现了连续两个1(或0)则生成一个0(或1),其它情况则以相等概率随机生成0或1。

有了数据我们接下来要用RNN去学习这些数据,看看它能不能学习到我们产生这些数据时使用的策略,即数据间的联系。评判RNN是否学习到规律以及学习的效果如何的依据,是我们在第三章里介绍过的交叉熵损失函数。根据我们生成数据的规则,如果RNN没有学习到规则,那么它预测正确的概率就是0.5,否则它预测正确的概率为:0.5*0.5+0.5*1=0.75(在“data_without_rel”中,连续出现的两个数字的组合为:00、01、10和11。00和11出现的总概率占0.5,在这种情况下,如果RNN学习到了规律,那么一定能预测出下一个数字,00对应1,11对应0。而如果出现的是01或10的话,RNN预测正确的概率就只有0.5,所以综合起来就是0.75)。

根据交叉熵损失函数,在没有学习到规律的时候,其交叉熵损失为:

loss = - (0.5 * np.log(0.5) + 0.5 * np.log(0.5)) = 0.6931471805599453

在学习到规律的时候,其交叉熵损失为:

Loss = -0.5*(0.5 * np.log(0.5) + np.log(0.5))

=-0.25 * (1 * np.log(1) ) - 0.25 * (1 *np.log(1))

=0.34657359027997264

4. 我们定义“generate_epochs”方法处理生成的数据:

ec0efaad38708718fae42245b3537bc10e1315c6

5. 接下来实现RNN部分:

bf69aeb9ced445485b429a5b789d7499b39764cd

6. 定义RNN模型:

727cd9d027fa6d3a6546897a1a5075d7b501a11e

这里我们使用了“dynamic_rnn”,因此每次会同时处理所有batch的第一组数据,总共处理的次数为:batch_size / time_step。

27d7a5c2d8c6dbd368c3808f4a493312482c4e51

7. 到这里,我们已经实现了整个RNN模型,接下来初始化相关数据,看看RNN的学习效果如何:

f82e742478a0b2def1c717577914e3495c388c17

定义数据集的大小为500000,每个batch的大小为2000,RNN的“时间步”设为5,隐藏层的神经元数目为6。将训练过程中的loss可视化,结果如下图中的左侧图像所示:

9a2084aaa85516602b8114469279f8c39c5df251

图1 二进制序列数据训练的loss曲线

从左侧loss曲线可以看到,loss最终稳定在了0.35左右,这与我们之前的计算结果一致,说明RNN学习到了序列数据中的规则。右侧的loss曲线是在调整了序列关系的时间间隔后(此时的time_step过小,导致RNN无法学习到序列数据的规则)的结果,此时loss稳定在0.69左右,与之前的计算也吻合。


原文发布时间为:2018-11-15

本文来自云栖社区合作伙伴“磐创AI”,了解相关信息可以关注“磐创AI”。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值