GRU与LSTM 简单程序

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达
    
    

fb614855753e04f4bc421bfb7f4899c3.jpeg

编辑 | 安可

出品 | 磐创AI技术团队

目录:

  • 门控循环神经网络简介

  • 长短期记忆网络(LSTM)

  • 门控制循环单元(GRU)

  • TensorFlow实现LSTM和GRU

  • 参考文献

一、 门控循环神经网络

门控循环神经网络在简单循环神经网络的基础上对网络的结构做了调整,加入了门控机制,用来控制神经网络中信息的传递。门控机制可以用来控制记忆单元中的信息有多少需要保留,有多少需要丢弃,新的状态信息又有多少需要保存到记忆单元中等。这使得门控循环神经网络可以学习跨度相对较长的依赖关系,而不会出现梯度消失和梯度爆炸的问题。如果从数学的角度来理解,一般结构的循环神经网络中,网络的状态6279ddcb624e9b674bf238a36582093e.pngc3078dac0db1ba66129ba68b06cf8597.png之间是非线性的关系,并且参数W在每个时间步共享,这是导致梯度爆炸和梯度消失的根本原因。门控循环神经网络解决问题的方法就是在状态fb4b2b926d727fc5ccc39ec02a01fee3.pngf5b8f7e612d09b65e58f0c5fa0835794.png之间添加一个线性的依赖关系,从而避免梯度消失或梯度爆炸的问题。

二、 长短期记忆网络(LSTM)

长短期记忆网络(Long Short-term Memory,简称LSTM)的结构如图1所示,LSTM[1]的网络结构看上去很复杂,但实际上如果将每一部分拆开来看,其实也很简单。在一般的循环神经网络中,记忆单元没有衡量信息的价值量的能力,因此,记忆单元对于每个时刻的状态信息等同视之,这就导致了记忆单元中往往存储了一些无用的信息,而真正有用的信息却被这些无用的信息挤了出去。LSTM正是从这一点出发做了相应改进,和一般结构的循环神经网络只有一种网络状态不同,LSTM中将网络的状态分为内部状态和外部状态两种。LSTM的外部状态类似于一般结构的循环神经网络中的状态,即该状态既是当前时刻隐藏层的输出,也是下一时刻隐藏层的输入。这里的内部状态则是LSTM特有的。

在LSTM中有三个称之为“门”的控制单元,分别是输入门(input gate)、输出门(output gate)和遗忘门(forget gate),其中输入门和遗忘门是LSTM能够记忆长期依赖的关键。输入门决定了当前时刻网络的状态有多少信息需要保存到内部状态中,而遗忘门则决定了过去的状态信息有多少需要丢弃。最后,由输出门决定当前时刻的内部状态有多少信息需要输出给外部状态。

9ba90f42ff1e609a08c4a7393717f231.jpeg

图1 单个时间步的LSTM网络结构示意图

从上图我们可以看到,一个LSTM单元在每个时间步都会接收三个输入,当前时刻的输入,来自上一时刻的内部状态42c146f60bb9d0b5dcc320b84de1d97e.png以及上一时刻的外部状态dc02d9da9024f0e40522c0c64a1f506a.png。其中,82876f7fccef63019618b1f1981c0715.png59c0f9c60c558391d7dbd3364c32f766.png同时作为三个“门”的输入。3609231b3a5722572f733b1901fa1a39.png为Logistic函数。

接下来我们将分别介绍LSTM中的几个“门”结构。首先看一下输入门,如图2所示:

2226c9dc6235aa975b01f76549260477.jpeg

图2 LSTM的输入门结构示意图

LSTM中也有类似于RNN(这里特指前面介绍过的简单结构的循环神经网络)的前向计算过程,如图2,如果去掉输入门部分,剩下的部分其实就是RNN中输入层到隐藏层的结构,“tanh”可以看作是隐藏层的激活函数,从“tanh”节点输出的值为:

3ac28c1f85fd45ee1e37c7cab70ab540.png

式1

上式中,参数的下标“c”代表这是“tanh”节点的参数,同理,输入门参数的下标为“i”,输出门参数的下标为“o”,遗忘门参数的下标为“f”。上式与简单结构循环神经网络中隐藏层的计算公式一样。在LSTM中,我们将“tanh”节点的输出称为候选状态49b8c1d111a2fd0548de17531a35e260.png

输入门是如何实现其控制功能的?输入门的计算公式如下:

b84dcee3317d3069519e00bdec5f213b.png

式2

由于834c22d689a61e4dc50a8fd146748ff9.png为Logistic函数,其值域为(0,1),因此输入门的值就属于(0,1)。LSTM将“tanh”节点的输出(即候选状态641f8ae40b8f6c5ff1f919abc6375ff6.png)乘上输入门的值后再用来更新内部状态。如果的值趋向于0的话,那么候选状态75423a2cb70bf2ed409f74681cb171b9.png就只有极少量的信息会保存到内部状态中,相反的,如果的值e6d519cb45595f19daf4f6eef43a308a.png趋近于1,那么候选状态aca6f2caffaf8fff90a43f8eb13b12a0.png就会有更多的信息被保存。输入门就是通过这种方法来决定保存多少中的信息,2c600cca81002210947d123bec0290e1.png值的大小就代表了新信息的重要性,不重要的信息就不会被保存到内部状态中.

再来看遗忘门,如图3所示:

47aa29608eadc9cb01da382e5e60ec31.jpeg

图3 LSTM的遗忘门结构示意图

遗忘门的计算公式如下:

7c6600153ce09855337160f93d736d4d.png

式3

和输入门是同样的方法,通过的值来控制上一时刻的内部状态有多少信息需要“遗忘”。当fbbee811dfc155660868bd4003b6489b.png的值越趋近于0,被遗忘的信息越多。同样的原理,我们来看“输出门”,如图4所示。输出门的计算公式如下:

286253a1133231fd7195dfc0440725a8.png

式4

59bf83cfdf063e1712387b8113a46493.png的值月接近于1,则当前时刻的内部状态a331a528125547642f53037b6cd2947f.png就会有更多的信息输出给当前时刻的外部状态fac74cc5962057c06632167808274168.png

e6f969c955fed89715680ee380861611.png

图4 LSTM的输出门结构示意图

以上就是LSTM的整个网络结构以及各个“门”的计算公式。通过选择性的记忆和遗忘状态信息,使的LSTM要比一般的循环神经网络能够学习更长时间间隔的依赖关系。根据不同的需求,LSTM还有着很多不同的变体版本,这些版本的网络结构大同小异,但都在其特定的应用中表现出色。

三、 门控制循环单元(GRU)

门控制循环单元(gated recurrent unit,GRU)网络是另一种基于门控制的循环神经网络,GRU[2]的网络结构相比LSTM要简单一些。GRU将LSTM中的输入门和遗忘门合并成了一个门,称为更新门(update gate)。在GRU网络中,没有LSTM网络中的内部状态和外部状态的划分,而是通过直接在当前网络的状态67a03b542d1d4059e4778eb8bdf0b75f.png和上一时刻网络的状态41ff133fdb5045b45dc709ae5a58c1f8.png之间添加一个线性的依赖关系,来解决梯度消失和梯度爆炸的问题。

043676394be4377bb9a17e7db7422723.jpeg

图5 单个时间步的GRU网络结构示意图

在GRU网络中,更新门用来控制当前时刻输出的状态74875ca9ef33ce8fb5ffc26722354a6c.png中要保留多少历史状态d903df612ed764f5aa01212b2e07022e.png,以及保留多少当前时刻的候选状态564ad015bb649bea42d86911163e13ea.png。更新门的计算公式如下:

66f4b1af9ecf9715460952e4c89785d8.png

式5

如图5所示,更新门的输出分别和历史状态35457ad0e4e95df88f777ca9fefe1175.png以及候选状态39799204cebb9b86c38afda14fed7042.png进行了乘操作,其中和0c546f7b39c3478c9146b54933bd296a.png相乘的是6c8e0ffe8a3a33a420ecc21e6d5719f7.png。最终当前时刻网络的输出为:

15a60ab4a9a47f0e2ec762bd587f77c0.png

式6

重置门的作用是决定当前时刻的候选状态是否需要依赖上一时刻的网络状态以及需要依赖多少。从图5可以看到,上一时刻的网络状态c46a94d91f192b77f109a30daf84adf1.png先和重置门的输出相乘之后,再作为参数用于计算当前时刻的候选状态。重置门的计算公式如下:

b96b026c0e697e09454eb3a3c2b0cfa2.png

式7

caef23b5fafd671ef4778a13d5a7d1c5.png的值决定了候选状态1bfa2dbbf7da92f0b39f149c80e54acd.png对上一时刻的状态b7221af61e1c312091bb892b12a587fa.png的依赖程度,候选状态106ac8a820e80e581bd00e733eb23d33.png的计算公式如下:

e92695b1137b034a5db1acc523f0dcbd.png

式8

其实当c46732c64147fa76546eb2cea41ec8d2.png的值为0且06817db371cd33301ebdeeb6a279542e.png的值为1时,GRU网络中的更新门和重置门就不再发挥作用了,而此时的GRU网络就退化成了简单循环神经网络,因为此时有:

9f637aff46d5c88981e326fca5395435.png

式9

四、 TensorFlow实现LSTM和GRU

前面介绍了LSTM和GRU的理论知识,这一小节里我们使用TensorFlow来实现一个LSTM模型。为了方便,这里我们使用前面介绍过的mnist数据集。可能读者对于在循环神经网络中使用图像数据会有一点疑惑,因为通常情况下图像数据一般都是使用卷积神经网络来训练。事实的确是这样,由于卷积神经网络和循环神经网络的结构不同,也就使得它们各自有不同的适用场景,但这不代表卷积神经网络只能用来处理时序数据,同样也不能认为循环神经网络不能用来处理图像数据,只要在输入数据的格式上稍作调整即可,就像上一章中我们使用卷积神经网络网络来处理文本数据一样。

mnist数据集我们在第三章中就已经使用过,这里就不再多做介绍了,直接上代码:

84ac78000de918d9c648f844217c2aab.png

我们首先导入需要的包,然后定义了神经网络中的一些相关参数。其中第6行代码定义了LSTM中的时间步的长度,由于我们mnist数据集的图像大小为28X28,所以我们将一行像素作为一个输入,这样我们就需要有28个时间步。第7行代码定义了每个时间步输入数据的长度(每个时间步的输入是一个向量),即一行像素的长度。

3213fc22138728991febfc84f0e4985f.png

第10行代码用来加载mnist数据集,并通过参数“validation_size”指定了验证集的大小。第16行代码用来将mnist数据集的格式转换成“dynamic_rnn”函数接受的数据格式“[batch_size, max_time,data_length]”。

465fbbb930eaacaedc5b4a45d4505f97.jpeg

在上面的代码中,我们定义了一个两层的LSTM网络结构,并使用了交叉熵损失函数和“Adam”优化器。LSTM多层网络结构的定义和我们前面使用过的多层神经网络的定义方法一样,只是将“BasicRNNCell”类换成了“BasicLSTMCel”类。

652b0a09f3bdf5314b6530f69eb6dc82.png

在上面的整个代码中,我们使用的参数都是比较随意的进行选择的,没有进行任何的优化,最终在测试集上的结果能达到96%左右,当然这肯定不是LSTM网络处理mnist数据集所能达到的最好的效果,有兴趣的读者可以试着去调整网络的结构和参数,看是否能达到更高的准确率。

TensorFlow中实现LSTM和GRU的切换非常简单,在上面的代码中,将第22和26行代码注释掉,然后取消第24和27行代码的注释,实现的就是GRU。

本文介绍了门控循环神经网络LSTM以及GRU的原理及其tensorflow代码实现,希望能让大家对常用到的LSTM及GRU能够有更好的理解。下一篇,我们将介绍RNN循环神经网络的应用部分,分析RNN循环神经网络是怎样用在文本分类,序列标注以及机器翻译上的,以及其存在的不足与改进方法。

五、 参考文献

[1]Sepp Hochreiter: Long Short-term Memory .1997

[2]Kazuki Irie, Zoltán Tüske, TamerAlkhouli, Ralf Schlüter, Hermann Ney:

LSTM, GRU, Highway and a Bit of Attention:AnEmpirical Overview for Language Modeling in Speech Recognition.INTERSPEECH2016: 3519-3523

 
   

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇

 
   

3f1a97d227a96427aa838902bb5f5b19.jpeg


    
    
  1. 下载 1:OpenCV-Contrib扩展模块中文版教程
  2. 在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
  3. 下载 2:Python视觉实战项目 52
  4. 在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等 31个视觉实战项目,助力快速学校计算机视觉。
  5. 下载 3:OpenCV实战项目 20
  6. 在「小白学视觉」公众号后台回复:OpenCV实战项目 20讲,即可下载含有 20个基于OpenCV实现 20个实战项目,实现OpenCV学习进阶。
  7. 交流群
  8. 欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
@[TOC](这里写自定义目录标题)

标题

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值