LSTM逻辑设计详细解读

1:前言

之前在用LSTM做时序问题分类,如单变量预测、多变量预测、LSTM+CNN做时空卷积神经网络算法设计时,涉及算法调参过程时,对参数量和Num_Units的概念没有完全理解掌握,所以重新对LSTM自底向上重新梳理一遍。包括LSTM的设计原理,解决的主要问题,以及输入输出结果图。
备注:本篇博客主要是逻辑概念的梳理,不涉及到复杂数学推导,如反向传播计算,只有一些简单的线性矩阵运算。


2:目录

  • 2.1 RNN的介绍与应用于什么场景
  • 2.2 SimpleRNN的缺陷
  • 2.3 LSTM的设计逻辑
  • 2.4 LSTM各个门的激活函数
  • 2.5 LSTM实现的详细结构图
  • 2.6 相关引申问题
  • 2.7 参考文献及资料

2.1 RNN的介绍与应用于什么场景

在传统神经网络模型中,是从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的,这种网络结构对很多问题解决办法有限。循环神经网络RNN的目的使用来处理序列数据,例如输入一段词汇:“I arrive BeiJing in November” ,需要预测目的地和出发地,如果利用前向神经网络实现,我们将BeiJing这个单词经过词汇编码输入后,预测为目的概率预测值是最大。但后面假设在这个网络中又输入一段词汇:“I leave BeiJing in November”,由于输入单词都是BeiJing,因此目的地为BeiJing的概率最大,由第二个输入语句可以看到BeiJing实际上应该预测为出发地的概率最大。

而循环神经网络RNN具有记忆功能:当输入“I arrive BeiJing in November”时,在看到BeiJing这个词汇之前,已经看到并记忆arrive 这个词汇,将BeiJing预测为目的地的概率最大。当输入“I leave BeiJing in November”,在看到BeiJing这个词汇之前,已经看到并记忆leave这个词汇,因此将BeiJing预测为出发地的概率最大。

RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关,即在每个时刻做决策的时候都考虑一下上一个时刻的决策结果。理论上,RNN可以对任何长度的序列数据进行处理。在实践中,为了降低复杂性往往假设当前的状态只与前面的某几个状态相关。下图是一个展开的RNN结构示意图。
在这里插入图片描述
中间隐层节点,是RNN的memory (记忆)部分
图中可以看到每个时间点的决策会受前一时间点决策的影响

计算公式如下:
O t = f ( X ⋆ W + Y t − 1 ⋆ V + b ) \mathrm{O}_{\mathrm{t}}=\mathrm{f}\left(\mathrm{X}^{\star} \mathrm{W}+\mathrm{Y}_{\mathrm{t}-1}^{\star} \mathrm{V}+\mathrm{b}\right) Ot=f(XW+Yt1V+b),其中W、V、b是模型的参数,下标t代表当前的序列位置(时间点),t-1代表上个位置(上个时间点),X是当前的输入,f(·)是激活函数,*是矩阵乘法,O是模型输出。

上图是一个无隐藏层的循环神经网络,叫做“simple RNN”。是后续LSTM、GRU等门限循环网络的基础模型。


2.2 Simple RNN存在的问题

simple RNN虽然可以解决时序数据的依赖关系,但也存在问题:
比如在一个比较长的时序中(假设1000个),权重矩阵为W,在进行前向传播过程中,由于该权重矩阵W是共享的,当我们要输出t=1000时的数据,那么就是 w 1000 w^{1000} w1000,如果W大于1,则连乘1000此可能大的惊人,如果W小于1,则连乘1000此可能小的可以忽略。也许有人会说当W大于1时,我们将学习率调小一点,当W大于1时,将学习率调大一点?事实上,由于时序过长,无论调小或者调大,其最后计算结果都是一个极端值,下图是个例子:
在这里插入图片描述

如当W=1.01时,最后时序输出结果约为20000
在这里插入图片描述
如当W=0.99时,最后时序输出结果约为0
在这里插入图片描述
很容易看到,这是一个指数爆炸或指数衰减的过程。同样在计算反向传播的时候也是如此。误差反向传播过程可以近似看做输出层误差乘上倒数第t个时刻的梯度,此时跟前向传播结果类似,当更新靠前时刻的参数时,计算的梯度要么非常大,要么非常小,前面时刻的参数将停止更新。这就是梯度消失或者梯度爆炸。


针对Simple RNN存在的问题,我们可以想到那些解决办法?

Simple RNN在训练时,error loss 可能某段时间非常大,某段时间又非常小,是因为同样的权重矩阵W,在时间(序列)转换计算过程中,反复使用,如果一点权重矩阵W有影响,导致要么梯度非常大,要么梯度非常小,非常极端。

既然Simple RNN梯度变化非常大,那么是否可以设置一个阈值,当达到这个阈值时,梯度停止更新,以此避免发生影响模型之前积累的结果。但是,这样的阈值设定也有缺陷:当梯度达到阈值后,模型难以再学习很靠前的时序数据信息了,因为梯度停止更新。

因此Simple RNN在理论设计上可以保存任意长的时序数据来辅助当前时间点的决策,然而由于在训练模型时,梯度无法准确合理的传到很靠前的时间点(要么太大,要么太小),因此Simple RNN实际上只能记住并不是很长的时间序列信息数据。只是说传统前馈网络网络,可以记住的历史信息要长一些,但是Simple RNN无法记住长距离的信息。

针对Simple RNN的缺陷和设定阈值的启发思想,先引出接下来的Standard RNN,最后看LSTM是如何解决Simple RNN的缺陷问题


2.3:LSTM的设计逻辑

在Simple RNN模型中,每个时间点的决策受当前时刻和前一时间点的决策的综合影响,而我们很难看到Simple RNN与记忆具有某种交互联系。

我们回到神经网络模型刚开始设计出来的参照模板:神经网络模型本身就是模拟人脑结构的信息输入、记忆、输出,是生物脑神经结构各种功能在数学世界的投影。

假设我们现在正玩德州扑克,每一次我们出牌选择肯定会经过以下步骤:

  1. 查看现在手里现在准备要出的牌
  2. 回忆过去已经出牌的场景
  3. 综合1,2的信息,做出最合理的出牌决策

上面玩扑克过程中,第1步准备要出的牌就是当前时刻的外部输入X(t),第2步就是调用过去历史出牌时刻的信息(记忆),第3步就是结合1,2的综合信息来推理出当前的出牌结果,也就是当前时刻的输出Y(t)。

在上面的第2步过程中,回忆历史时刻出牌的信息时,并不是回忆过去具体出了那张牌,而是回忆一个大致模糊而包含大多数时刻出牌选择的场景。也就是说,在做时序任务模型的时候,并不是直接将上一时刻的输出结果Y(t-1)直接连接起来,而是连接一个模糊抽象的东西(记忆模块)。

这个记忆模块不就是神经网络中的隐层层节点信息嘛,也就是说,每一次出牌选择时,参考的是各个时刻的记忆模块。我们可以将这个出牌意识选择模型抽象为下面这样的:

在这里插入图片描述
上面加入隐藏层的循环神经网络模型就是经典的 RNN 神经网络模型,即“Standard RNN” (由此逐渐引出LSTM)。

根据以上德州扑克的案例,结合Simple RNN的缺陷,我们逐步开始设计一种可以解决梯度消失问题从而可以记住长距离依赖关系的神经网络结构。

(1) 信息传输的问题
首先需要解决问题是如何防止让梯度随着时间增加发生梯度消失或者梯度爆炸的现象?---- 最简单的思路,即让梯度恒等于1,也就是说无论时间怎么变化,前后记忆中的内容是一样的,记忆信息因此可以没有任何损耗一直可以传输下去,网络前端与网络末端的远距离依赖关系也可以学到。我们把需要设计的记忆单元记为c,根据这个思路,记忆单元的数学表达即为: c ( t ) = c ( t − 1 ) c(t)=c(t-1) c(t)=c(t1)。OK,梯度消失的问题,通过这个思路可以解决了。

(2) 信息装载问题----如何将新的信息输入记忆单元C中
根据simple RNN中对新信息的定义:当前时刻的外部输入x(t)与前一时刻的网络输出(即反馈单元)y(t-1)联合得到网络在当前这一时的新信息,记为 C ^ ( t ) \hat{C}(t) C^(t),表达如下:

c ^ ( t ) = f ( W ⋅ x ( t ) + V ∗ ⋅ y ( t − 1 ) ) \hat{c}(t)=f(W \cdot x(t)+V * \cdot y(t-1)) c^(t)=f(Wx(t)+Vy(t1))

那如何将当前时刻的新信息加入到记忆单元c中呢?就好比在德州扑克例子中,假设当前时刻 t 已经出牌了,那么如何将这个出牌信息加入到我们的记忆模块,以便后续出牌再次进行参考呢?
可能想到的方法有以下两种:

  1. 乘进去:将当前时刻得到的新信息 C ^ ( t ) \hat{C}(t) C^(t)与前一时刻记忆单元中的结果 c ( t ) c(t) c(t)相乘,即: c ( t ) = c ( t − 1 ) ∗ c ^ ( t ) c(t)=c(t-1) * \hat{c}(t) c(t)=c(t1)c^(t)
  2. 加进去:将当前时刻得到的新信息 C ^ ( t ) \hat{C}(t) C^(t)与前一时刻记忆单元中的结果 c ( t ) c(t) c(t)相加,即: c ( t ) = c ( t − 1 ) + c ^ ( t ) c(t)=c(t-1) + \hat{c}(t) c(t)=c(t1)+c^(t)

实际上,循环神经网络采用的是加法运算逻辑。如果是做乘法,那么梯度爆炸和梯度消失的问题将会更加明显(可以参考相关资料里的数学证明)。加法更适合做信息叠加,乘法更适合做控制和scaling。可以再联想到脑结构记忆容量,本身就是各个时刻的记忆累加。如果是乘法的话,那么我们大脑要么变成过目不忘,要么就变成失忆了。但有人会继续追问,按照这样理解,如果换成加法的话,每个时刻大脑都在装载新的信息,随着时间推移,我们记忆容量也会越来越多,不也变成过目不忘吗?先记住这个待解决的问题,下面介绍的门控制逻辑会针对此问题作出设计

(3) 输入门控制逻辑----降低梯度消失的可能性
在信息传输和信息装载同时存在的情况下,我们如何进一步让梯度消失的可能性降低呢?
仔细想想,大脑记忆结构并不是每时每刻都在添加新的信息,因为现实生活中只有很少的时刻,我们是可以长期记忆的,大部分记忆,可能没几天我们都忘记了。因此在(2)信息装载问题中, c ( t ) = c ( t − 1 ) + c ^ ( t ) c(t)=c(t-1) + \hat{c}(t) c(t)=c(t1)+c^(t),这个数学模型视图要记住每个时刻的信息显然是不合理的,我们只需要记住该记住的信息。
上面说到,乘法更适合做控制和scaling,对新信息选择要不要记忆是一个控制逻辑 ,所以应该用乘法规则。即在前面加一个乘法控制阀门: c ( t ) = c ( t − 1 ) + g i n ∗ c ^ ( t ) c(t)=c(t-1)+g_{i n} * \hat{c}(t) c(t)=c(t1)+ginc^(t) g i n g_{i n} gin就是输入门,取值范围为0~1,表示需要是否需要增加新信息,很容易想到用sigmoid函数作为输入门的激活函数,因为sigmoid的输出范围是0到1之间。

上面是对一个长时记忆单元的控制。而大脑有很多歌记忆神经元,我们需要设置更多的记忆单元的,每个长时记忆单元都有它专属的输入门,在数学上我们不妨使用来表示这个按位相乘的操作,用大写字母 ⊗ \otimes 来表示长时记忆单元集合。即: C ( t ) = C ( t − 1 ) + g i n ⊗ C ^ ( t ) C(t)=C(t-1)+g_{i n} \otimes \hat{C}(t) C(t)=C(t1)+ginC^(t),正如我们只需要记住该记住的时刻新信息,因此大部分输入门只会在必要的时刻为开启状态,也就是说大部分时刻下, C ( t ) = C ( t − 1 ) C(t)=C(t-1) C(t)=C(t1),这样加法操作带来的梯度爆炸或梯度消失的可能性更低了。

(4)遗忘门控制逻----降低记忆单元c的饱和度
前面我们说到,我们只需要记住该记住的时刻信息就好了,但这里有一个问题:如果在某个时刻,该网络结构在输入一些信息量很大的数据时,导致输入门始终处于开启状态,视图记住所有的这些信息,会导致什么结果?(也就是(2)中我们保留的一个问题

肯定会导致脑容量达到峰值,也就是记忆单元c的值会非常大。而在网络输出的时候,我们是需要把c激活的,当c变得非常大时,sigmoid、tanh这些常见的激活函数的输出就趋于完全饱和。也就是说,我们的脑容量达到峰值,快记不住这么多的信息了。

想象我们的脑结构,我们之所以记不住太多的信息,觉得大脑容量有限,原因就是因为容易忘记一些其他的事,因此可以设计以个遗忘门,当新信息要输入的时候,先通过遗忘门来忘记一些记忆,从而再考虑要不要接受该时刻的新信息。

显然, 遗忘门来控制好记忆消失程度的,也需要用乘法运算,到目前,我们设计的网络变成: c ( t ) = g forget  c ( t − 1 ) + g in  ∗ c ^ ( t ) c(t)=g_{\text {forget }} c(t-1)+g_{\text {in }} * \hat{c}(t) c(t)=gforget c(t1)+gin c^(t),或者向量形式: C ( t ) = g forget  C ( t − 1 ) + g in ⊗ C ^ ( t ) C(t)=g_{\text {forget }} C(t-1)+g_{\text {in}} \otimes \hat{C}(t) C(t)=gforget C(t1)+ginC^(t)

至此,我们已经解决了如何避免梯度消失的问题,也考虑到如何将新信息加入到记忆单元,并且还考虑到新信息输入太过丰富而可能导致输入门始终开启的状况。只剩最后一步:考虑记忆单元的输出问题了

(5)输出门控制逻辑----输出跟当前任务有关时刻相关记忆单元
前面的Simple RNN 提到,输出步骤就是激活当前记忆单元的内容: y ( t ) = f ( c ( t ) ) y(t)=f(c(t)) y(t)=f(c(t))

想一想,现实生活中,我们在处理目前的事情时,需要将我们脑容量中的所有记忆单元中的内容都要输出一遍吗?肯定不是,我们只需要跟目前处理的事情任务相关的记忆单元输出结果就行。类似于遗忘门设计的思想,我们也应该给记忆单元添加一个阀门: y ( t ) = g o u t ∗ f ( c ( t ) ) y(t)=g_{o u t} * f(c(t)) y(t)=goutf(c(t))

(6)输入门、遗忘门、输出门的控制
在最后,我们只需要定义输入门、遗忘门、输出门受谁的控制。
按照Simple RNN的思想:是否可以让各个门受当前时刻的外部输入x(t)和上一时刻的输出y(t-1)的综合影响影响就行了?

想一想,为了解决Simple RNN存在的问题,在设计这个新的网络时,引入了3个门控制逻辑,特别是输出门,如果按照Simple RNN的控制逻辑,当输出门一旦关闭,就好导致后面时序中的记忆全部被截断,下一时刻的各个门的输入仅仅受当前外部输入 x ( t ) x(t) x(t)了。

我们可以将长时记忆单元接入各个门:把上一时刻的长时记忆 c ( t − 1 ) c(t-1) c(t1)接入遗忘门和输入门,把当前时刻的长时记忆 x ( t ) x(t) x(t)接入输出门(当信息流动到输出门的时候,当前时刻的长时记忆已经被计算完成了)

计算公式如下:
g i n ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ c ( t − 1 ) ) g_{i n}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot y(t-1)+U \cdot c(t-1)) gin(t)=sigm(Wx(t)+Vy(t1)+Uc(t1))
g forget  ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ c ( t − 1 ) ) g_{\text {forget }}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot y(t-1)+U \cdot c(t-1)) gforget (t)=sigm(Wx(t)+Vy(t1)+Uc(t1))
g out ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ c ( t ) ) g_{\text {out}}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot y(t-1)+U \cdot c(t)) gout(t)=sigm(Wx(t)+Vy(t1)+Uc(t))

到目前,我们的网络结构计算总结如下:
C ( t ) = g forget  C ( t − 1 ) + g in ⊗ C ^ ( t ) C(t)=g_{\text {forget }} C(t-1)+g_{\text {in}} \otimes \hat{C}(t) C(t)=gforget C(t1)+ginC^(t)
C ^ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) \hat{C}(t)=f(W \cdot x(t)+V \cdot y(t-1)) C^(t)=f(Wx(t)+Vy(t1))
y ( t ) = g out ⊗ f ( C ( t ) ) y(t)=g_{\text {out}} \otimes f(C(t)) y(t)=goutf(C(t))
g i n ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{i n}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot y(t-1)+U \cdot C(t-1)) gin(t)=sigm(Wx(t)+Vy(t1)+UC(t1))
g forget ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{\text {forget}}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot y(t-1)+U \cdot C(t-1)) gforget(t)=sigm(Wx(t)+Vy(t1)+UC(t1))
g o u t ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t ) ) g_{o u t}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot y(t-1)+U \cdot C(t)) gout(t)=sigm(Wx(t)+Vy(t1)+UC(t))

其中W、V、U 分别表示输入门输出门遗忘门的权重矩阵,上面公式看着一大堆,其实就是三个门的控制逻辑运算和输入、输出计算。


在simple RNN 的基础上,我们加入了隐藏层,由此过渡到standard RNN,模仿standard RNN的做法,用隐层单元h来替换结果输出y。

C ( t ) = g forget C ( t − 1 ) + g in ⊗ C ^ ( t ) C(t)=g_{\text {forget}} C(t-1)+g_{\text {in}} \otimes \hat{C}(t) C(t)=gforgetC(t1)+ginC^(t)
C ^ ( t ) = f ( W ⋅ x ( t ) + V ⋅ h ( t − 1 ) ) \hat{C}(t)=f(W \cdot x(t)+V \cdot h(t-1)) C^(t)=f(Wx(t)+Vh(t1))
y ( t ) = g o u t ⊗ f ( C ( t ) ) y(t)=g_{o u t} \otimes f(C(t)) y(t)=goutf(C(t))
g i n ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ h ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{i n}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot h(t-1)+U \cdot C(t-1)) gin(t)=sigm(Wx(t)+Vh(t1)+UC(t1))
g f o r g e t ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ h ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{f o r g e t}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot h(t-1)+U \cdot C(t-1)) gforget(t)=sigm(Wx(t)+Vh(t1)+UC(t1))
g o u t ( t ) = sigm ⁡ ( W ⋅ x ( t ) + V ⋅ h ( t − 1 ) + U ⋅ C ( t ) ) g_{o u t}(t)=\operatorname{sigm}(W \cdot x(t)+V \cdot h(t-1)+U \cdot C(t)) gout(t)=sigm(Wx(t)+Vh(t1)+UC(t))


2.4 各个门的激活函数

关于激活函数的选取,LSTM中,遗忘门、输入门和输出门使用Sigmoid函数作为激活函数,在生成候选记忆时,使用Tanh双曲正切函数作为激活函数。这两个激活函数都是饱和的,即在输入达到一定值得情况下,输出不会发生变化。因为如果使用非饱和激活函数,,比如ReLu函数,那么将难以实现门控的效果。

Sigmoid函数输出值在0 ~ 1之间,符合门控的物理意义,且当输入较大或者较小时,其输出会非常接近1或0,从而保证开和关的逻辑设计。在产生候选记忆时,使用Tanh函数,是因为输出在-1 ~ 1之间,这与大多数场景下特征分布是0中心吻合。此外Tanh函数在输入为0附近比Sigmoid函数具有更大的梯度,通常使模型收敛更快。


2.5 下图就是LSTM详细网络结构图

在这里插入图片描述
短时记忆单元:
由于h随时都可以被输出门截断,所以我们可以很感性的把h理解为短时记忆单元。

长时记忆单元:
由于梯度只从c走的时候,存在一条没有连续相乘的路径,可以避免梯度消失。又有遗忘门避免激活函数和梯度饱和,因此c为长时记忆单元

全文总结:

  1. 为了解决RNN中的梯度消失的问题,为了让梯度无损传播,想到了c(t)=c(t-1)这个朴素梯度传播模型,所以将c称为“长时记忆单元”。
  2. 然后为了把新信息平稳安全可靠的装入长时记忆单元,我们引入了“输入门”。
  3. 然后为了解决新信息装载次数过多带来的激活函数饱和的问题,引入了“遗忘门”。
  4. 然后为了让网络能够选择合适的记忆进行输出,我们引入了“输出门”。
  5. 然后为了解决记忆被输出门截断后使得各个门单元受控性降低的问题,引入了“peephole”连接。
  6. 然后为了将神经网络的简单反馈结构升级成模糊历史记忆的结构,引入了隐单元h,并且发现h中存储的模糊历史记忆是短时的,于是记h为短时记忆单元。
  7. 于是该网络既具备长时记忆,又具备短时记忆,起名叫“长短时记忆神经网络(Long Short Term Memory Neural Networks,简称LSTM)“。

2.6 相关引申问题

LSTM为什么在梯度消失上处理的更好?

继续以李宏毅老师的课件为例:
在这里插入图片描述
每个时间段t,RNN读入一个变量 x t , x t x_{t}, x_{t} xt,xt和上一个阶段产生的隐含层的信息 h t − 1 h_{t-1} ht1一起被写入计算模块 f ( ⋅ ) f(\cdot) f()里面,产生新的隐含层信息 h t h_{t} ht,新的隐含层信息一方面继续传播更新,另一方面产生该时刻的输出 y t y_{t} yt。隐含层信息 h t h_{t} ht和所有戒指时间t的历史输入都有关系,并依赖RNN的信息路径一直传播下去。

RNN的细节表示,用公式表示x,h,y的关系有:
在这里插入图片描述
不考虑偏置的情况下, W h W^{h} Wh是旧的隐含层信息对新的隐含层信息的影响, W i W^{i} Wi是输入信息新的隐含层信息的影响, W h W^{h} Wh W i W^{i} Wi, W o W^{o} Wo都是要学习的参数。

定义损失函数,采用梯度下降求解参数:
L = ∑ t = 0 T L t L=\sum_{t=0}^{T} L_{t} L=t=0TLt

损失函数考虑了所有时间t的误差,利用BPTT反向传播求解参数模型,对上述两公式求梯度:

∂ L ∂ W = ∑ t = 0 T ∂ L t ∂ W \frac{\partial L}{\partial W}=\sum_{t=0}^{T} \frac{\partial L_{t}}{\partial W} WL=t=0TWLt

L t L_{t} Lt分别对 W h , W i , W o W^{h}, W^{i}, W^{o} Wh,Wi,Wo进行求导:

∂ L t ∂ W o = ∑ t = 0 T ∂ L t ∂ y t ∂ y t ∂ W o \frac{\partial L_{t}}{\partial W^{o}}=\sum_{t=0}^{T} \frac{\partial L_{t}}{\partial y_{t}} \frac{\partial y_{t}}{\partial W^{o}} WoLt=t=0TytLtWoyt (1)
∂ L t ∂ W i = ∑ t = 0 T ∑ k = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ j = k + 1 t ∂ h j ∂ h j − 1 ) ∂ h k ∂ W i \frac{\partial L_{t}}{\partial W^{i}}=\sum_{t=0}^{T} \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial y_{t}} \frac{\partial y_{t}}{\partial h_{t}}\left(\prod_{j=k+1}^{t} \frac{\partial h_{j}}{\partial h_{j-1}}\right) \frac{\partial h_{k}}{\partial W^{i}} WiLt=t=0Tk=0tytLthtyt(j=k+1thj1hj)Wihk (2)
∂ L t ∂ W h = ∑ t = 0 T ∑ k = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ j = k + 1 t ∂ h j ∂ h j − 1 ) ∂ h k ∂ W h \frac{\partial L_{t}}{\partial W^{h}}=\sum_{t=0}^{T} \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial y_{t}} \frac{\partial y_{t}}{\partial h_{t}}\left(\prod_{j=k+1}^{t} \frac{\partial h_{j}}{\partial h_{j-1}}\right) \frac{\partial h_{k}}{\partial W^{h}} WhLt=t=0Tk=0tytLthtyt(j=k+1thj1hj)Whhk (3)

梯度消失主要就是针对上面(2)和(3)两个式子,可以看到,上面公式里面有依赖于时间 t 的连乘符号,在修正某个位置 t 的误差时,计算出的梯度需要考虑 t 之前的所有时间 k 的隐含层信息对时间 t 的隐含层信息的影响。当 k 和 t 越远时,这个影响被迭代的次数就越多,对应着隐含层之间的连乘次数就越多。于是产生梯度消失,实际上梯度爆炸也是这个原因导致的!

进一步,上述连乘步骤是如何发生作用的?

首先,根据RNN的定义,把隐含层之间的函数关系表示出来,具体有:
h t = σ ( W i x t + W h h t − 1 ) h_{t}=\sigma\left(W^{i} x_{t}+W^{h} h_{t-1}\right) ht=σ(Wixt+Whht1)

其中, σ \sigma σ 表示Sigmoid激活函数,于是有:

∂ h j ∂ h j − 1 = σ ′ W h \frac{\partial h_{j}}{\partial h_{j-1}}=\sigma^{\prime} W^{h} hj1hj=σWh

根据Sigmoid的特性,存在关系: σ ′ = σ ( 1 − σ ) σ ∈ ( 0 , 1 ) \sigma^{\prime}=\sigma(1-\sigma) \quad \sigma \in(0,1) σ=σ(1σ)σ(0,1)

显而易见, σ ′ \sigma^{\prime} σ存在上界,为 1 4 \frac{1}{4} 41 ,那么:
(1) W h W^{h} Wh > 4时, σ ′ W h \sigma^{\prime} W^{h} σWh 一直大于1,因此无论参数如何取值,当 j,k 距离很大时,连乘项都会趋向于无穷,在这种情况下就会导致梯度爆炸;
(2) W h W^{h} Wh < 1时, σ ′ W h \sigma^{\prime} W^{h} σWh一直小于1,因此无论参数如何取值,当 [公式] 距离很大时,连乘项都会趋于0,在这种情况下就会导致梯度消失。

LSTM神经网络输入输出究竟是怎样的?

在这里插入图片描述

2.7 隐藏层cell中的参数

在这里插入图片描述
结合上面讲述的LSTM的输入输出,接下来继续hidden中的cell参数是怎么来的?
上下图的结构就是一个LSTM单元,里面的每个黄框是一个神经网络,这个网络的隐藏单元个数我们设为hidden_size,那么这个LSTM单元里就有4*hidden_size个隐藏单元。
每个LSTM输出的都是向量,包括 C t , h t C_{t}, h_{t} Ct,ht,,它们的长度都是当前LSTM单元的hidden_size。

用LSTMBlockCell构造了一个LSTM单元,单元里的隐藏单元个数是hidden_size,有四个神经网络,每个神经网络的输入是 h t − 1 h_{t-1} ht1 x t x_{t} xt
,将它们concat到一起,维度为(hidden_size+wordvec_size),所以LSTM里的每个黄框的参数矩阵的维度为
[hidden_size+wordvec_size,hidden_size]

需要注意的是,num_steps(步长)个时刻的LSTM都是共享一套参数的,说是有num_steps个LSTM单元,其实只有一个,只不过是对这个单元执行num_steps次。

例子

举个例子,比如一批训练64句话,每句话20个单词,每个词向量长度为200,隐藏层单元个数为128

那么训练一批句子,输入的张量维度是[64,20,200], h t , c t h_{t}, c_{t} ht,ct

的维度是[128],那么LSTM单元参数矩阵的维度是[128+200,4*128],

在时刻1,把64句话的第一个单词作为输入,即输入一个[64,200]的矩阵,由于会和 h t h_{t} ht 进行concat,输入矩阵变成了[64,200+128],输入矩阵会和参数矩阵[200+128,4 x 128]相乘,输出为[64,4 x 128],也就是每个黄框的输出为[64,128],黄框之间会进行一些操作,但不改变维度,输出依旧是[64,128],即每个句子经过LSTM单元后,输出的维度是128,所以上一章节的每个LSTM输出的都是向量,包括C_t,h_t,它们的长度都是当前LSTM单元的hidden_size 得到了解释。那么我们就知道cell_output的维度为[64,128]

之后的时刻重复刚才同样的操作,那么outputs的维度是[20,64,128].

softmax相当于全连接层,将outputs映射到vocab_size个单词上,进行交叉熵误差计算。

然后根据误差更新LSTM参数矩阵和全连接层的参数

2.8参考资料

  1. 书籍:《深度学习》、《百面机器学习》
  2. 知乎《LSTM神经网络输入输出》、《为什么相比于RNN,LSTM在梯度消失上表现更好》
  • 45
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值