深度学习笔记(九):LSTM学习笔记(结构解析,从RNN的发展历程,解决梯度爆炸和梯度消失,长短时间记忆的含义)

零、前置知识

可以先看一下如下前置链接

循环神经网络(RNN)的解释说明及其梯度爆炸或消失的tricks

从最简单的前馈神经网络引出了简单的循环神经网络,起名叫 “simple RNN” 。这种方式即在每个时刻做决策的时候都考虑一下上一个时刻的决策结果。如图所示:

在这里插入图片描述

这张图中圆球里的下半球代表两向量的内积,上半球代表将内积结果激活

虽然通过这种简单反馈确实可以看出每个时间点的决策会受前一时间点决策的影响,但是对记忆力的理解还是不够清晰。

RNN的一步前向过程的输出符合一个人类的记忆过程根据外部的输入进行判断,调用之前的决策输出历史信息(之前的输入所获得的状态,这个状态不仅仅是上面一个状态,而是长时间的概括集合),结合推理出决策结果。

但是人们在做很多时序任务的时候,尤其是稍微复杂的时序任务时,潜意识的做法并不是直接将上个时刻的结果输出y(t-1)直接连接进来,而是连接一个模糊而抽象的东西进来!这个东西是神经网络中的隐结点S。也就是说,人们潜意识里直接利用的是一段历史记忆融合后的东西 S S S,而不单单是上一时间点的输出。而网络的输出则取自这个隐结点。所以更合理的刻画人的潜意识的模型应该是这样的:
在这里插入图片描述

O t = f o ( V ∗ S + b o ) S t = f s ( U ∗ X t + W ∗ S t − 1 + b s ) O_t = f_o(V*S+b_o)\\ S_t = f_s(U*X_t+W*S_{t-1}+b_s) Ot=fo(VS+bo)St=fs(UXt+WSt1+bs)
记忆在隐单元中存储和流动,输出取自隐单元,这种加入了隐藏层的循环神经网络就是经典的RNN神经网络,即 “standard RNN”

RNN从simple到standard的变动及其意义对于本文后续内容非常重要哦。

(简单到标准的过程涉及到LSTM的发展思路所以在这篇blog中进行介绍)

上一篇blog中也有介绍,RNN的记忆单元是短时的。 这就是LSTM发展的历程了。

一、LSTM目标

可以解决梯度爆炸消失问题,从而记住长距离依赖关系,让记住长期信息成为神经网络的默认行为,而不是需要很大力气才能学会。

二、LSTM的结构解析

在这里插入图片描述

LSTM的关键是单元状态 C C C(cell state),即下图中LSTM单元上方从左贯穿到右的水平线,它像是传送带一样,主要功能是将信息从上一个单元传递到下一个单元,它和其他部分只有很少的线性的相互作用。
在这里插入图片描述

门结构的介绍

LSTM通过“门”(gate)来控制丢弃或者增加信息,从而实现遗忘或记忆的功能。“门”是一种使信息选择性通过的结构,由一个sigmoid函数( σ ( ⋅ ) \sigma(\cdot) σ())和一个点乘操作组成。sigmoid函数的输出值在[0,1]区间,0代表完全丢弃,1代表完全通过。一个LSTM单元有三个这样的门,分别是遗忘门(forget gate)、输入门(input gate)、输出门(output gate)。对应的控制信号分别表示为 g f o r g e t , g i n p u t , g o u t p u t g_{forget},g_{input},g_{output} gforget,ginput,goutput

具体操作可以概括为输入信号 I I I,被控制信号 O O O,输入信号经过黄色框的神经网络之后通过sigmoid函数激活得到一个控制信号 g ∗ g_* g,这也就是门的输出,用这个信号去调控之前的被控制信号 O o r i g i n a l O_{original} Ooriginal,得到新的信号 O n e w O_{new} Onew,具体的控制方式应该是按位相乘(自己的想法)或点乘(点乘是因为每次输入的是一个值输出的也是一个值)。
g ∗ = σ ( W ⋅ I ) O n e w = g ∗ ⋅ O o r i g i n a l \begin{aligned} g_* &= \sigma(W \cdot I)\\ O_{new} &= g_* \cdot O_{original} \end{aligned} gOnew=σ(WI)=gOoriginal
在这里插入图片描述

遗忘门(forget gate)

输入: 上一单元的输出 h t − 1 h_{t-1} ht1和本单元的输入 x t x_t xt

输出: 用来控制上一单元状态信息遗忘程度的控制信号 f t f_t ft

主要想法与目的: 通过本时间点的输入和上时间点的决策产生一个控制信号,控制信号为[0,1]内的值作为权重系数与上一个时间点状态分量 C t − 1 C_{t-1} Ct1中的存储值相乘,达到控制上一单元状态被遗忘的程度的目的。

结构图:

在这里插入图片描述

公式表达:
控制信号的产生:  f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) W f , b f : 遗 忘 门 所 对 应 的 权 重 矩 阵 和 偏 置 矩 阵 \begin{aligned} \text{控制信号的产生: } f_t &= \sigma(W_f \cdot[h_{t-1},x_t]+b_f )\\ W_f,b_f:& 遗忘门所对应的权重矩阵和偏置矩阵\\ \end{aligned} 控制信号的产生ftWf,bf:=σ(Wf[ht1,xt]+bf)

输入门(input gate)

输入: 上一单元的输出 h t − 1 h_{t-1} ht1和本单元的输入 x t x_t xt

输出: 用来控制本单元信息以何种形式加入的控制信号 i t i_t it

主要想法与目的: 首先将本次时间节点的输入信息和上次的决策信息归纳成新的候选状态信息 C ~ t \tilde{C}_{t} C~t,这个候选信息需要被选择性的加入的,不可以直接将本次输入对应的信息会被传递进入历史信息(可能会出现本次信息量较大,因为只能被选择性的控制加入);接着通过产生存储值为[0,1]的控制信号作为权重系数与候选状态信息 C ~ t \tilde{C}_{t} C~t相乘,达到控制新信息加入的目的。

结构图:

公式表达:
控制信号的产生:  i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) 候选状态信息的产生:  C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) W i , b i : 输 入 门 所 对 应 的 权 重 矩 阵 和 偏 置 矩 阵 W C , b C : 生 成 候 选 加 入 信 息 所 对 应 的 权 重 矩 阵 和 偏 置 矩 阵 完成状态信息的更新:  C t = f t ∗ C t − 1 + i t ∗ C ~ t \begin{aligned} \text{控制信号的产生: }i_{t} & =\sigma\left(W_{i} \cdot\left[h_{t-1}, x_{t}\right]+b_{i}\right) \\ \text{候选状态信息的产生: }\tilde{C}_{t} &=\tanh \left(W_{C} \cdot\left[h_{t-1}, x_{t}\right]+b_{C}\right)\\ W_i,b_i:& 输入门所对应的权重矩阵和偏置矩阵\\ W_C,b_C:& 生成候选加入信息所对应的权重矩阵和偏置矩阵\\ \text{完成状态信息的更新: }C_{t}&=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t} \end{aligned} 控制信号的产生it候选状态信息的产生C~tWi,bi:WC,bC:完成状态信息的更新Ct=σ(Wi[ht1,xt]+bi)=tanh(WC[ht1,xt]+bC)=ftCt1+itC~t

输出门(output gate)

输入: 上一单元的输出 h t − 1 h_{t-1} ht1和本单元的输入 x t x_t xt

输出: 用来控制当前的单元状态有多少被过滤掉的控制信号 o t o_t ot

主要想法与目的: 控制当前的单元状态有多少被过滤掉。先将单元状态激活,输出门为其中每一项产生一个在[0,1]内的值,控制单元 h t h_t ht被过滤传到下一级的程度。

结构图:

在这里插入图片描述

公式表达:
控制信号的产生:  o t = σ ( W o [ h t − 1 , x t ] + b o ) 最终输出决策的产生:  h t = o t ∗ tanh ⁡ ( C t ) W o , b o : 生 成 候 选 加 入 信 息 所 对 应 的 权 重 矩 阵 和 偏 置 矩 阵 \begin{aligned} \text{控制信号的产生: }o_{t}&=\sigma\left(W_{o}\left[h_{t-1}, x_{t}\right]+b_{o}\right) \\ \text{最终输出决策的产生: }h_{t}&=o_{t} * \tanh \left(C_{t}\right)\\ W_o,b_o&: 生成候选加入信息所对应的权重矩阵和偏置矩阵\\ \end{aligned} 控制信号的产生ot最终输出决策的产生htWo,bo=σ(Wo[ht1,xt]+bo)=ottanh(Ct):

总结一下前馈结构流程

再看一眼结构图
在这里插入图片描述

1、 g f o r g e t g_{forget} gforget受当前时刻的外部输入 x ( t ) x(t) x(t)、上一时刻的输出(短时记忆) h ( t − 1 ) h(t-1) h(t1)的控制(也可添加上一时刻的长时记忆 C ( t − 1 ) C(t-1) C(t1)作为控制门的输入)。

2、 g i n g_{in} gin受当前时刻的外部输入 x ( t ) x(t) x(t)、上一时刻的输出(短时记忆) h ( t − 1 ) h(t-1) h(t1)的控制(也可添加上一时刻的长时记忆 C ( t − 1 ) C(t-1) C(t1)作为控制门的输入)。

3、由当前时刻外部输入 x ( t ) x(t) x(t)和上一时刻的短时记忆 h ( t − 1 ) h(t-1) h(t1)计算出当前时刻的新信息 C ~ t \tilde{C}_{t} C~t

4、然后由遗忘门 g f o r g e t g_{forget} gforget控制上一单元长时记忆单元 C ( t − 1 ) C(t-1) C(t1)去遗忘一些信息 g f o r g e t g_{forget} gforget,输入门 g i n g_{in} gin控制当前时刻的部分新信息 C ~ t \tilde{C}_{t} C~t选择性载入长时记忆单元,产生新的长时记忆 C ( t ) C(t) C(t),可以传输到下一时间单元。

5、 g o u t g_{out} gout受当前时刻的外部输入 x ( t ) x(t) x(t)、上一时刻的输出(短时记忆) h ( t − 1 ) h(t-1) h(t1)的控制(也可添加上一时刻的长时记忆 C ( t ) C(t) C(t)作为控制门的输入)。

6、将至目前积累下来的记忆 C ( t ) C(t) C(t)激活,然后在输出门 g o u t g_{out} gout把控下,选出部分作为这一时刻我们关注的记忆(此处即为短时间记忆),也可以理解为当前时间的决策信息 h ( t ) h(t) h(t),再把这部分记忆进行输出。
遗忘门控制:  f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) 输入门控制:  i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) 新信息产生:  C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) 状态信息的更新:  C t = f t ∗ C t − 1 + i t ∗ C ~ t 输出门控制:  o t = σ ( W o [ h t − 1 , x t ] + b o ) 当前时间决策的产生:  h t = o t ∗ tanh ⁡ ( C t ) \begin{aligned} \text{遗忘门控制: } f_t &= \sigma(W_f \cdot[h_{t-1},x_t]+b_f )\\ \text{输入门控制: }i_{t} & =\sigma\left(W_{i} \cdot\left[h_{t-1}, x_{t}\right]+b_{i}\right)\\ \text{新信息产生: }\tilde{C}_{t} &=\tanh \left(W_{C} \cdot\left[h_{t-1}, x_{t}\right]+b_{C}\right)\\ \text{状态信息的更新: }C_{t}&=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t}\\ \text{输出门控制: }o_{t}&=\sigma\left(W_{o}\left[h_{t-1}, x_{t}\right]+b_{o}\right) \\ \text{当前时间决策的产生: }h_{t}&=o_{t} * \tanh \left(C_{t}\right)\\ \end{aligned} 遗忘门控制ft输入门控制it新信息产生C~t状态信息的更新Ct输出门控制ot当前时间决策的产生ht=σ(Wf[ht1,xt]+bf)=σ(Wi[ht1,xt]+bi)=tanh(WC[ht1,xt]+bC)=ftCt1+itC~t=σ(Wo[ht1,xt]+bo)=ottanh(Ct)

三、LSTM变种

上面描述的LSTM是一种标准版本,并不是所有LSTM都和上面描述的一模一样。事实上,似乎每篇论文用到的LSTM都有一点细微的不同。

一种比较流行的LSTM变种如下图所示,最早由Gers & Schmidhuber在2000年提出。这种方法增加了“peephole connections”,即每个门都可以“窥探”到单元状态。这里,遗忘门和输入门是与上一单元状态建立连接,而输出门是与当前单元状态建立连接。

f t = σ ( W f ⋅ [ C t − 1 , h t − 1 , x t ] + b f ) i t = σ ( W i ⋅ [ C t − 1 , h t − 1 , x t ] + b i ) o t = σ ( W o ⋅ [ C t , h t − 1 , x t ] + b o ) \begin{aligned} f_{t} &=\sigma\left(W_{f} \cdot\left[C_{t-1}, h_{t-1}, x_{t}\right]+b_{f}\right) \\ i_{t} &=\sigma\left(W_{i} \cdot\left[C_{t-1}, h_{t-1}, x_{t}\right]+b_{i}\right) \\ o_{t} &=\sigma\left(W_{o} \cdot\left[C_{t}, h_{t-1}, x_{t}\right]+b_{o}\right) \end{aligned} ftitot=σ(Wf[Ct1,ht1,xt]+bf)=σ(Wi[Ct1,ht1,xt]+bi)=σ(Wo[Ct,ht1,xt]+bo)
有一个变种取消了输入门,将新信息加入的多少与旧状态保留的多少设为互补的两个值(和为1),即:只有当需要加入新信息时,我们才会去遗忘;只有当需要遗忘时,我们才会加入新信息。
在这里插入图片描述

C t = f t ∗ C t − 1 + ( 1 − f t ) ∗ C ~ t C_{t}=f_{t} * C_{t-1}+\left(1-f_{t}\right) * \tilde{C}_{t} Ct=ftCt1+(1ft)C~t
另外一个值得关注的变种看起来很好玩,叫做Gated Recurrent Unit(GRU),最早由Cho, et al.在2014年提出。这种方法将遗忘门和输入门连入了一个“更新门”(update gate),同时也合并了隐藏状态 h t h_t ht和单元状态 C t C_t Ct,最终的结果比标准LSTM简单一些。

四、解决问题的思路(从循环神经网络到LSTM)

问题一:解决随时间的流动梯度发生的指数级消失或者爆炸的情况

解决方法: 即迫使算出来的梯度恒为1。因为1的任何次方都是1。

设计的长时记忆单元记为 C e l l Cell Cell,那么我们设计出来的长时记忆单元的数学模型就是这样子:
C e l l ( t ) = C e l l ( t − 1 ) Cell(t)= Cell(t-1) Cell(t)=Cell(t1)
这样的话,误差反向传播时的导数就恒定为1。误差可以一路无损耗的向前传播到网络的前端,从而学习到遥远的前端与网络末端的远距离依赖关系。

问题二:将信息装入长时记忆单元

解决方法: 问题一的解决方式是建立在 C e l l Cell Cell中存储了历史信息信息 C C C的基础之上,那么 C e l l Cell Cell就能把当前信息和历史信息都带到输出层。在 t t t时刻算出来的梯度信息存储在 C e l l ( t ) Cell(t) Cell(t)里后,它也能把梯度一路带到时刻 t = 0 t=0 t=0而无任何损耗。如果我们能把信息包装进 C e l l Cell Cell里面, C e l l Cell Cell就可以解决信息存储和流动的问题。

问题转变为新信息如何产生概括,如何加载进入旧的信息。

其实类比RNN中的当前时刻的外部输入 x ( t ) x(t) x(t)与前一时刻的网络输出​联合起来,作为输入,适当操作(经过网络或者激活等等)后,定义为整个结构在当前这一时刻get到的新信息,这个信息呢定义为本节点的信息,至于怎么加载进入就的信息让它传递下去等会再考虑,记为 C ~ ( t ) \tilde{C}(t) C~(t)。即:
C ~ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) \tilde{C}(t)=f(W \cdot x(t)+V \cdot y(t-1)) C~(t)=f(Wx(t)+Vy(t1))
这个公式是针对于RNN的,LSTM使用的是 h ( t − 1 ) h(t-1) h(t1),并且没有单独设置权重矩阵,不过这些问题都不大,我们需要理解的知识思路过程。

有了新信息 C ~ ( t ) \tilde{C}(t) C~(t)如何融合进旧的信息 C ( t − 1 ) {C}(t-1) C(t1),从而形成新的信息$ C(t)$传递到以后的单元中,有两种方案:乘法和加法。

其实稍微一想就很容易判断:乘法操作更多的是作为一种对信息进行某种控制的操作(比如任意数与0相乘后直接消失,相当于关闭操作;任意数与大于1的数相乘后会被放大规模等),而加法操作则是新信息叠加旧信息的操作。

下面我们深入的讨论一下乘性操作加性操作,这在理解LSTM里至关重要。

论乘法:

设置记忆信息添加数学模型为:
C ( t ) = C ( t − 1 ) ⋅ C ~ ( t ) C(t)=C(t-1) \cdot \tilde{C}(t) C(t)=C(t1)C~(t)
因此网络完整数学模型如下:
C ~ t ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) C ( t ) = C ( t − 1 ) ⋅ C ~ ( t ) y ( t ) = f ( C ( t ) ) \begin{aligned} \tilde{C}_{t}(t) &=f(W \cdot x(t)+V \cdot y(t-1)) \\ C(t)&=C(t-1) \cdot \tilde{C}(t) \\ y(t)&=f(C(t)) \end{aligned} C~t(t)C(t)y(t)=f(Wx(t)+Vy(t1))=C(t1)C~(t)=f(C(t))
为了计算方便导数,假设激活函数为线性激活(即没有激活函数。实际上tanh在小值时可以近似为线性,relu在正数时也为线性),这时网络模型简化为:
C ~ ( t ) = W ⋅ x ( t ) + V ⋅ y ( t − 1 ) y ( t ) = C ( t ) = C ( t − 1 ) ⋅ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) \begin{aligned} \tilde{C}(t) &=W \cdot x(t)+ V \cdot y(t-1) \\ y(t)& =C(t)= C(t-1) \cdot (W \cdot x(t)+V \cdot y(t-1))\\ \end{aligned} C~(t)y(t)=Wx(t)+Vy(t1)=C(t)=C(t1)(Wx(t)+Vy(t1))
假如网络经过了 T T T个时间步到输出端计算 l o s s loss loss,这时若要更新 t = 0 t=0 t=0时刻下网络参数 V V V的权重,则即对 t = 0 t=0 t=0时刻的参数 V V V求偏导,即计算
∂ loss ⁡ ( t = T ) ∂ V ( t = 0 ) \frac{\partial \operatorname{loss}(t=T)}{\partial V(t=0)} V(t=0)loss(t=T)

loss ⁡ ( t = T ) = f loss ( y ( t = T ) ) \operatorname{loss}(t=T)=f_{\text {loss}}(y(t=T)) loss(t=T)=floss(y(t=T))

(其中的 f l o s s ( ⋅ ) f_{loss}(·) floss()为损失函数)

链式求导法则 t = 0 t = 0 t=0时刻的矩阵进行参数更新时
∂ loss ⁡ ( t = T ) ∂ V ( t = 0 ) = ∂ f loss ( y ( t = T ) ) ∂ V ( t = 0 ) = ∂ f loss ( y ( t = T ) ) ∂ y ( t = T ) ∂ y ( t = T ) ∂ V ( t = 0 ) \begin{array}{l} \frac{\partial \operatorname{loss}(t=T)}{\partial V(t=0)}&=\frac{\partial f_{\text {loss}}(y(t=T))}{\partial V(t=0)}\\ &=\frac{\partial f_{\text {loss}}(y(t=T))}{\partial y(t=T)} \frac{\partial y(t=T)}{\partial V(t=0)}\\ \end{array} V(t=0)loss(t=T)=V(t=0)floss(y(t=T))=y(t=T)floss(y(t=T))V(t=0)y(t=T)
中的 ∂ f loss ( y ( t = T ) ) ∂ y ( t = T ) \frac{\partial f_{\text {loss}}(y(t=T))}{\partial y(t=T)} y(t=T)floss(y(t=T))的值就是我们要往前传的梯度(参数更新信息),这一项是针对本单元进行更新的所以不会涉及到指数级别的衰减或者增长,则我们的目标就是讨论 ∂ y ( t = T ) ∂ V ( t = 0 ) \frac{\partial y(t=T)}{\partial V(t=0)} V(t=0)y(t=T)

V V V求偏导时其他变量(就是说的 W W W x x x)自然也就成了常量,这里我们再做一个过分简化,直接删掉 y ( t ) y(t) y(t)中与 V V V无关的项 W ⋅ x ( t ) W \cdot x(t) Wx(t)!同时因为是线性激活 f ( ⋅ ) f(\cdot) f()也被省略。同时在y二阶乘方存在的情况下忽略一阶乘方),这时就可以直接展开。
C ~ ( 1 ) = V ⋅ y ( 0 ) C ~ ( 2 ) = V ⋅ y ( 1 ) = V ⋅ C ( 1 ) = V ⋅ C ( 0 ) ⋅ C ~ ( 1 ) = V ⋅ C ( 0 ) ⋅ V ⋅ y ( 0 ) = 常数 ⋅ V 2 ⋅ y ( 0 ) … C ~ ( T ) = 常数 ⋅ V T ⋅ y ( 0 ) \begin{array}{l} \tilde{C}(1) &=& V\cdot y(0)\\ \tilde{C}(2) &=& V\cdot y(1)=V\cdot C(1) = V\cdot C(0) \cdot \tilde{C}(1) = V\cdot C(0) \cdot V\cdot y(0) \\ &=& \text{常数}\cdot V^2\cdot y(0)\\ &\dots&\\ \tilde{C}(T) &=&\text{常数}\cdot V^T\cdot y(0) \end{array} C~(1)C~(2)C~(T)====Vy(0)Vy(1)=VC(1)=VC(0)C~(1)=VC(0)Vy(0)常数V2y(0)常数VTy(0)

所以输出值可以带换成关于 V V V的公式

y ( t = T ) = f ( C ( T ) ) = C ( T ) = C ( T − 1 ) ⋅ C ~ ( T ) = C ( 0 ) ⋅ C ~ ( 1 ) ⋅ C ~ ( 2 ) ⋯ ⋅ C ~ ( T − 1 ) C ~ ( T ) = 常数 ⋅ V 1 ⋅ y ( 0 ) ⋅ V 2 ⋅ y ( 0 ) ⋯ ⋅ V T ⋅ y ( 0 ) = 常数 ⋅ V ( 1 + T ) ⋅ T 2 ⋅ y ( 0 ) T \begin{array}{l} y(t=T) &= f(C(T)) = C(T)\\ &=C(T-1) \cdot \tilde{C}(T)\\ &=C(0)\cdot \tilde{C}(1) \cdot \tilde{C}(2) \dots \cdot \tilde{C}(T-1)\tilde{C}(T)\\ &=\text{常数} \cdot V^1 \cdot y(0)\cdot V^2 \cdot y(0) \dots \cdot V^T \cdot y(0)\\ &=\text{常数} \cdot V^{\frac{(1+T)\cdot T}{2}} \cdot {y(0)}^T \end{array} y(t=T)=f(C(T))=C(T)=C(T1)C~(T)=C(0)C~(1)C~(2)C~(T1)C~(T)=常数V1y(0)V2y(0)VTy(0)=常数V2(1+T)Ty(0)T

所以的迭代下去会发现如果说RNN的 V T V^T VT是根据T指数级梯度爆炸和消失,那这 种信息加载方式会形成 ( 1 + T ) ⋅ T 2 {\frac{(1+T)\cdot T}{2}} 2(1+T)T级别的指数梯度爆炸和消失。

所以说直接将新信息乘进长时记忆单元只会让情况更糟糕,导致当初 C ( t ) = C ( t − 1 ) C(t)=C(t-1) C(t)=C(t1)让导数恒为1的构想完全失效,这也说明了乘性更新并不是简单的信息叠加,而是控制和scaling。

论加法:

如果改成加性规则呢?此时添加信息的数学模型为
C ( t ) = C ( t − 1 ) + C ~ ( t ) C(t)=C(t-1) + \tilde{C}(t) C(t)=C(t1)+C~(t)
与前面的做法一样,假设线性激活并代入网络模型后得到
y ( t ) = C ( t ) = C ( t − 1 ) + C ~ ( t ) = C ( t − 1 ) + W ⋅ x ( t ) + V ⋅ y ( t − 1 ) = C ( t − 1 ) + W ⋅ x ( t ) + V ⋅ C ( t − 1 ) = ( V + 1 ) ⋅ y ( t − 1 ) + W ⋅ x ( t ) \begin{aligned} y(t)&=C(t) = C(t-1)+\tilde{C}(t) \\ &=C(t-1)+W \cdot x(t)+V \cdot y(t-1) \\ &=C(t-1)+W \cdot x(t)+V \cdot C(t-1) \\ &=(V+1) \cdot y(t-1)+ W \cdot x(t) \end{aligned} y(t)=C(t)=C(t1)+C~(t)=C(t1)+Wx(t)+Vy(t1)=C(t1)+Wx(t)+VC(t1)=(V+1)y(t1)+Wx(t)
最终 y ( T ) y(T) y(T)在传递中导数也会存在指数项。不过由于 V V V加了一个偏置1,导致爆炸的可能性远远大于消失。不过通过做梯度截断,也能很大程度的缓解梯度爆炸的影响。所以梯度消失的概率小了很多,梯度爆炸也能勉强缓解,看起来比RNN靠谱多了,毕竟控制好爆炸的前提下,梯度消失的越慢,记忆的距离就越长嘛。

因此,在往长时记忆单元添加信息方面,加性规则要显著优于乘性规则。也证明了加法更适合做信息叠加,而乘法更适合做控制和scaling。

由此,我们就确定只能应用加性规则,至此我们设计的网络应该是这样子的:
c ( t ) = c ( t − 1 ) + c ~ ( t ) c ^ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) y ( t ) = f ( c ( t ) ) \begin{aligned} c(t)&=c(t-1)+\tilde{c}(t) \\ \hat{c}(t)&=f(W \cdot x(t)+V \cdot y(t-1))\\ y(t)&=f(c(t)) \end{aligned} c(t)c^(t)y(t)=c(t1)+c~(t)=f(Wx(t)+Vy(t1))=f(c(t))
那么有没有办法让信息装箱和运输同时存在的情况下,让梯度消失的可能性变的更低,让梯度爆炸的可能性和程度也更低呢?

我们往长时记忆单元添加新信息的频率肯定是很低的,现实生活中只有很少的时刻我们可以记很久,大部分时刻的信息没过几天就忘了。因此现在这种模型一股脑的试图永远记住每个时刻的信息的做法肯定是不合理的,我们应该只记忆应该被记住的信息。

显然,对新信息选择记或者不记当前信息是一个控制操作,应该使用乘性规则。因此在新信息前加一个控制阀门,只需要让公式变为
c ( t ) = c ( t − 1 ) + g i n ⋅ c ~ ( t ) c ~ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) y ( t ) = f ( c ( t ) ) \begin{array}{l} c(t)=c(t-1)+g_{in} \cdot \tilde{c}(t) \\ \tilde{c}(t)=f(W \cdot x(t)+V \cdot y(t-1))\\ y(t)=f(c(t)) \end{array} c(t)=c(t1)+ginc~(t)c~(t)=f(Wx(t)+Vy(t1))y(t)=f(c(t))
这个 g i n g_{in} gin我们就叫做 “输入门” ,取值[0,1]。为了实现这个取值范围,我们很容易想到使用 sigmoid函数作为输入门的激活函数,毕竟sigmoid的输出范围一定是在[0,1]。当然,这是对一个时间记忆状态值的控制。我们到时候肯定要设置很多记忆单元去存储状态值,要不然会出现记忆信息量不足。因此每个长时记忆单元都有它专属的输入门。第二章中所有的例子都是基于元素的,当输入或者输出为向量时,需要通过向量形式的多个门去控制信号的传递或者遗忘,这个时候就是用按位相乘(个人想法)。 由于输入门只会在必要的时候开启,因此大部分情况下公式可以看直接进行状态信息的传递 C ( t ) = C ( t − 1 ) C(t)=C(t-1) C(t)=C(t1),也就是我们最理想的状态。由此加性操作带来的梯度爆炸也大大减轻,梯度消失更更更轻。

问题三:频繁装填带来的问题

万一神经网络读到一段信息量很大的文本,每一个信息都想被记住,以致于这时输入门一直保持大开状态,狼吞虎咽的试图记住所有这些信息。就会导致 C ( t ) C(t) C(t)的值变的非常大!

要知道,我们的网络要输出的时候是要把 C C C激活,当 C C C变的很大时, s i g m o i d ( ⋅ ) sigmoid(\cdot) sigmoid() t a n h ( ⋅ ) tanh(\cdot) tanh()这些常见的激活函数的输出就完全饱和了!比如如图 t a n h ( ⋅ ) tanh(\cdot) tanh()
在这里插入图片描述

C C C很大时, t a n h ( ⋅ ) tanh(\cdot) tanh()趋近于1,这时 C C C变得再大也没有什么意义了,因为处在函数的饱和区。反映到现实生活中就是当前的大脑容量记不住这么多东西和历史信息了,如果再往里面加新的信息,决策也不会因为有新的状态信息的加入而有什么变化。

这种情况怎么办呢?显然relu函数这种正向无饱和的激活函数是一种选择,但是我们想要的是适用范围更广的模型。

所以解决方案就是还需要加一个门用来遗忘,“遗忘门”。这样每个时刻到来的时候,记忆要先通过遗忘门忘掉一些事情,也就是为了让历史状态信息尽可能的保证在一个灵敏的状态,再考虑要不要接受这个时刻的新信息。

显然,遗忘门是用来控制记忆消失程度的,因此也要用乘性运算,即我们设计的网络已进化成:
C ( t ) = g forget  ⋅ C ( t − 1 ) + g i n ⋅ C ~ ( t ) C(t)=g_{\text {forget }}\cdot C(t-1)+g_{i n} \cdot \tilde{C}(t) C(t)=gforget C(t1)+ginC~(t)
至此通过可控的去遗忘之前的信息,从而接受新信息,解决如何为长时记忆单元中让信息状态一直保持灵敏激活状态。

问题四:网络如何输出,决策如何传递

当前的输出为什么不能设置成仅仅激活当前的记忆?

直接使用RNN的 y ( t ) = f ( c ( t ) ) y(t)=f(c(t)) y(t)=f(c(t))会有什么问题?(其中 f ( ⋅ ) f(·) f()为激活函数)

首先第一个要思考的点是:当前的决策输出是否需要与所有的状态信息相关,换句话说我现在的决定是否需要考虑到所有的历史记录信息?假如人有1万个长时记忆的脑细胞,每个脑细胞记一件事情,那么我们在处理眼前的事情的时候是每个时刻都把这1万个脑细胞里的事情都回忆一遍吗?显然不需要,我们只会让其中一部分跟当前任务当前时刻相关的脑细胞输出,即应该给我们的长时记忆单元添加一个输出阀门!

其次第二个要思考的点是:当前的决策输出是否需要原封不动的全部传递到下面的时间单元中。好像也是不需要的。所以LSTM添加了输出阀门。
y ( t ) = g out ⋅ f ( c ( t ) ) y(t)=g_{\text {out}} \cdot f(c(t)) y(t)=goutf(c(t))

针对输出信号进行控制传递。

问题五:控制门受什么控制

那么我们最后再定义一下控制门们(输入门、遗忘门、输出门)受谁的控制。

这个问题在RNN中很明显因为没有历史状态信息 C ( t ) C(t) C(t)的存在,所以让各个门受当前时刻的外部输入 x ( t ) x(t) x(t)和上一时刻的输出 y ( t − 1 ) y(t-1) y(t1)的控制是理所当然的。

但是在我们这个新设计的网络中,多了一堆阀门。尤其注意到输出门,一旦输出门关闭,就会导致其控制的记忆 f ( c ( t ) ) f(c(t)) f(c(t))被截断,下一时刻各个门就仅仅受当前时刻的外部输入x(t)控制了!这显然不符合我们的设计初衷(尽可能的让决策考虑到尽可能久的历史信息)。

解决方案就是再把长时记忆单元作为输入接入各个门,即把上一时刻的长时记忆 c ( t − 1 ) c(t-1) c(t1)接入遗忘门和输入门,把当前时刻的长时记忆 c ( t ) c(t) c(t)接入输出门(当信息流动到输出门的时候,当前时刻的长时记忆已经被计算完成了。即
g in ( t ) = σ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ c ( t − 1 ) ) g forget ( t ) = σ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ c ( t − 1 ) ) g out ( t ) = σ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ c ( t ) ) \begin{array}{l} g_{\text {in}}(t)=\sigma(W \cdot x(t)+V \cdot y(t-1)+U \cdot c(t-1)) \\ g_{\text {forget}}(t)=\sigma(W \cdot x(t)+V \cdot y(t-1)+U \cdot c(t-1)) \\ g_{\text {out}}(t)=\sigma(W \cdot x(t)+V \cdot y(t-1)+U \cdot c(t)) \end{array} gin(t)=σ(Wx(t)+Vy(t1)+Uc(t1))gforget(t)=σ(Wx(t)+Vy(t1)+Uc(t1))gout(t)=σ(Wx(t)+Vy(t1)+Uc(t))
当然,这个让各个门考虑长时记忆的做法是后人打的补丁,这些从长时记忆单元到门单元的连接被称为 “peephole(猫眼)” 。这就是之后的LSTM变种第一种变种,在后文LSTM变种中会讲到。

总结一下逐步推进到这里的LSTM数学表达

其实到这里可以看出相较于RNN最大的加强与改进就是通过历史信息 C ( t ) C(t) C(t)的加入,达到了不仅仅通过输出 y ( t ) y(t) y(t),来进行信息传递的目的。为了以示区别使用 h ( t ) h(t) h(t)代替 y ( t ) y(t) y(t),因为RNN中 y ( t ) y(t) y(t)承担了信息传递的所有责任,但是 h ( t ) h(t) h(t)随时都可以被LSTM的输出门截断,所以我们可以很感性的把 h ( t ) h(t) h(t)理解为短时记忆单元。而从数学上看的话,更是短时记忆了,因为梯度流经 h h h的时候,经历的是 h ( t ) − > c ( t ) − > h ( t − 1 ) h(t)->c(t)->h(t-1) h(t)>c(t)>h(t1)的连环相乘的路径(在输入输出门关闭前),显然如前边的数学证明中所述,这样会发生梯度爆炸和消失,而梯度消失的时候就意味着记忆消失了,即 h ( t ) h(t) h(t)为短时记忆单元。

同样的思路可以再证明一下,由于长时间的历史状态信息只从 C C C走,存在一条无连环相乘的路径,由于加性模型可以避免梯度消失。又有遗忘门避免激活函数和梯度饱和,因此 C C C为长时记忆单元。

同时针对 x ( t ) x(t) x(t) h ( t ) h(t) h(t)我们设置了不同的权重矩阵 W , V W,V W,V,但其实用一个矩阵也可以,偏置矩阵暂时省略。
C ~ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) C ( t ) = g forget ⋅ C ( t − 1 ) + g in ⋅ C ~ ( t ) h ( t ) = g out ⋅ f ( C ( t ) ) g in ( t ) = σ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g forget ( t ) = σ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g out ( t ) = σ ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t ) ) \begin{aligned} \tilde{C}(t)&=f(W \cdot x(t)+V \cdot y(t-1))\\ C(t)&=g_{\text {forget}} \cdot C(t-1)+g_{\text{in}}\cdot \tilde{C}(t)\\ h(t)&=g_{\text {out}} \cdot f(C(t)) \\ g_{\text {in}}(t)&=\sigma(W \cdot x(t)+V \cdot y(t-1)+U \cdot C(t-1)) \\ g_{\text {forget}}(t)&=\sigma(W \cdot x(t)+V \cdot y(t-1)+U \cdot C(t-1)) \\ g_{\text {out}}(t)&=\sigma(W \cdot x(t)+V \cdot y(t-1)+U \cdot C(t)) \end{aligned} C~(t)C(t)h(t)gin(t)gforget(t)gout(t)=f(Wx(t)+Vy(t1))=gforgetC(t1)+ginC~(t)=goutf(C(t))=σ(Wx(t)+Vy(t1)+UC(t1))=σ(Wx(t)+Vy(t1)+UC(t1))=σ(Wx(t)+Vy(t1)+UC(t))
至于门控输入和第二章提出的模型不太相同第二章提出的模型是门控输入是没有 U ⋅ C ( t ) U \cdot C(t) UC(t)的,其实这只是有无猫眼的两种LSTM的变体模式,有的认为应该由长时间历史信息 C C C,短时间决策信息 h h h,当前输入 x x x三者决定每个门的控制信号;无猫眼的情况下认为由短时间决策信息 h h h,当前输入 x x x二者决定每个门的控制信号即可。

五、通过名字的思考进行LSTM结构的总结

1、为了解决RNN中的梯度消失的问题,为了让梯度无损传播,想到了 C ( t ) = C ( t − 1 ) C(t)=C(t-1) C(t)=C(t1)这个朴素却没毛病的梯度传播模型,于是称c为“长时记忆单元”。

2、为了把新信息平稳安全可靠的装入长时记忆单元,引入了“输入门”。

3、为了解决新信息装载次数过多带来的激活函数饱和的问题,引入了“遗忘门”。

4、为了让网络能够选择合适的记忆进行输出,引入了“输出门”。

5、为了解决记忆被输出门截断后使得各个门单元受控性降低的问题,我们引入了“peephole”连接。

6、由于输出门的截断性,区别于RNN中单独承担信息传递责任的 y ( t ) y(t) y(t),发现 h h h中存储的模糊历史记忆是短时的,于是记 h h h为短时记忆单元。

7、于是该网络既具备长时记忆,又具备短时记忆,就干脆起名叫长短时记忆神经网络(Long Short Term Memory Neural Networks,简称LSTM)

六、参考文献

1.Hochreiter S, Schmidhuber J. Long Short-TermMemory[J]. Neural Computation,1997, 9(8): 1735-1780.

2.Gers F A, Schmidhuber J, Cummins F, et al.Learning to Forget: Continual Prediction with
LSTM[J]. Neural Computation,2000, 12(10): 2451-2471.

3.Gers F A,Schraudolph N N, Schmidhuber J, etal. Learning precise timing with lstm recurrent networks[J]. Journal of MachineLearning Research, 2003, 3(1):115-143.

4.A guide to recurrent neural networks and backpropagation. Mikael Bod ́en.

5.http://colah.github.io/posts/2015-08-Understanding-LSTMs/

6.《Supervised Sequence Labelling with Recurrent Neural Networks》Alex Graves

7.《Hands on machine learning with sklearn and tf》Aurelien Geron

8.《Deep learning》Goodfellow et.

9.Step-by-step to LSTM: 解析LSTM神经网络设计原理 - 知乎

10.深入理解LSTM神经网络

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值