LSTM公式及理解

机器学习 专栏收录该内容
20 篇文章 0 订阅

LSTM的基本结构及推导

这部分内容基本都是来自Step-by-step to LSTM: 解析LSTM神经网络设计原理,只是摘录了部分内容并添加了一些贫僧的想法。

LSTM公式与结构

LSTM(Long Short Term Memory,长短期记忆,注意这里的“长短期”,后面会提到是什么意思)的作者是个有点奇怪的人1,他的名字是Jürgen Schmidhuber(发音也挺奇怪)。LSTM的作者很有意思,如果读者感兴趣的话可以自己去看看相关资料(一定要去看作者本人的个人网站)。

接下来我们直接看LSTM,放一张网络上极为常见的图:

在这里插入图片描述

图片来自Understanding LSTM Networks,(丑的要死的)红色字体是贫僧加的。

注意图中hidden state(短期记忆)和cell state(长期记忆)的传递,以及输出其实就是 h t h_t ht

然后就是LSTM的计算公式:
输入门:
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
遗忘门:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma (W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)
C t ~ = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C_t} = \tanh{(W_C \cdot [h_{t-1}, x_t] + b_C)} Ct~=tanh(WC[ht1,xt]+bC)
输出门:
o t = σ ( W o [ h t − 1 , x t ] + b o ) o_t = \sigma (W_o [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
两种记忆:
长记忆: C t = f t ∗ C t − 1 + i t ∗ C t ~ C_t = f_t * C_{t-1} + i_t * \tilde{C_t} Ct=ftCt1+itCt~
短记忆: h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)

LSTM最重要的概念就是三个门:输入、输出、遗忘门;以及两个记忆:长记忆C、短记忆h。只要弄懂了这三个门两种记忆就可以弄明白LSTM了。这里先留个印象就行,接下来我们展开来讲。

从RNN、长时记忆说起

接下来我们直接进入正题,首先我们先从LSTM的父类RNN(循环神经网络,重点在循环)说起。提出RNN的目的是为了解决一个问题:LSTM以前的神经网络没有记忆,无法根据之前的输入来预测出输出。那怎么给神经网络加上记忆
记忆就是过去发生的事情,那么如果神经网络现在的输出受到过去输入的影响,是不是就可以说神经网络拥有了记忆?于是就有了这样的神经网络:

y ( t ) = f ( X ( t ) ⋅ W + y ( t − 1 ) ⋅ V + b ) y(t) = f(X(t) \cdot W + y(t - 1) \cdot V + b) y(t)=f(X(t)W+y(t1)V+b)

这就是无隐藏层的循环神经网络,结构如下:

在这里插入图片描述
上图来自Step-by-step to LSTM: 解析LSTM神经网络设计原理,一如既往, x x x是输入, y y y是输出。此图与上面的公式无关,只是作为“过去的输入能够影响现在的输出”的例子,意会即可

注意,上图中的圆圈(神经“圆”)共享权重 W W W V V V,即上面画的神经元都是同一个神经元,只是代表同一个神经元在不同时序时候的状态。所有数据依次通过同一个cell然后cell不断更新自己权重2。RNN都这样,后面会详细说明的LSTM也是这样。这里的“时序”就是常说的time step的单个step。time step通常就是一个句子的长度3;如果是batch的话那time step就是batch中最长句子的长度,其余的句子会通过pad来进行补齐(例如补充0作为pad,这部分是工程细节,不是重点)。

正是因为上图中神经元共享同样的权重,所以就有了“记忆”。为什么可以称为是“记忆”呢?可以用人日常的思考方式来帮助理解。假设您(又一次)在人生的道路上迷路了,您会4

  1. 看看现在走到了哪里
  2. 回忆一下之前走过的路
  3. 结合1和2的信息来决定该继续往哪里走

而这里的回忆一定是模糊而抽象的地形图,您通常不会回忆细节,例如路上某颗树的具体形状。而上图中的神经网络就是这样工作的,它将上一个时序的输出传给了下一个时序的神经元,而神经元做的处理就是用 V V V来乘 y ( t − 1 ) y(t - 1) y(t1),相当于是从上一个输出中抽象出(或者说提取出)了部分信息来作为当前决策的辅助工具。因此,我们的神经网络在某种程度上就具备了记忆

把所谓的“经过抽象的信息”用隐结点 h h h表示就可以得到:
在这里插入图片描述
这就是standard RNN。

总结一下:通过对上一个时序的神经元的输出提取抽象信息,并与输入一同输入进神经元中,就可以赋予神经元【记忆】。

抽象化信息

定义抽象化了的信息为: c ~ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) \tilde{c}(t) = f(W \cdot x(t) + V \cdot y(t - 1)) c~(t)=f(Wx(t)+Vy(t1))(先这么定义,后面会解释,反正记住 c ~ \tilde{c} c~是经过抽象化之后的信息就是了)。定义好抽象化信息是定义好了,问题来了,怎么装进到长时记忆单元?

将抽象出来的信息装载进长时记忆单元

有两种装载方法:

  1. 乘进去
  2. 加进去

这里我们用加法,具体为什么及相关推导过程可以看Step-by-step to LSTM: 解析LSTM神经网络设计原理,这里略要提一下结果。乘法更多是作为对信息控制的操作(类似阀门和放大器的组合);加法则是新信息和旧信息叠加的操作(其实这也挺符合常识,所以这里没有作太多解释)。此外,LSTM中长时记忆单元最怕的就是梯度爆炸和消失,如果用乘法的话就会导致梯度/爆炸的速度更快(具体推导过程看Step-by-step to LSTM: 解析LSTM神经网络设计原理,主要是贫僧不想敲公式了。。。反正不难)。总而言之,因为加法更适合做信息叠加,而乘法更适合做控制和scaling,所以我们使用加法来加载新的信息到记忆中。

装载信息: C ( t ) = C ( t − 1 ) + C ~ ( t ) C(t) = C(t - 1) + \tilde{C}(t) C(t)=C(t1)+C~(t)
抽象出来的信息: C ~ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) \tilde{C}(t) = f(W \cdot x(t) + V \cdot y(t - 1)) C~(t)=f(Wx(t)+Vy(t1))(其实就是当前输入和上一时序输出加权求和)
神经元当前时序的输出: y ( t ) = f ( c ( t ) ) y(t) = f(c(t)) y(t)=f(c(t)) f ( …   ) f(\dots) f()是激活函数)

但是如果只是不加选择地往记忆中添加信息,那么记忆中的有用信息会被无用的信息埋没,所以我们应该只往记忆中添加需要的信息。我们通过添加输入门来实现对信息进行挑选。

输入门

这么做其实很简单,只需要对新信息添加控制的阀门就可以了,所以这里要用到乘法:

C ( t ) = C ( t − 1 ) + g i n ⋅ C ~ ( t ) C(t) = C(t - 1) + g_{in} \cdot \tilde{C}(t) C(t)=C(t1)+ginC~(t)

g i n g_{in} gin就是输入门,取值0~1。因为取值0~1,所以通常使用sigmoid函数作为激活函数,所以 g i n = s i g m o i d ( …   ) g_{in} = \mathrm{sigmoid}(\dots) gin=sigmoid()

通常记忆单元是 n × m n \times m n×m的矩阵,而 g i n g_{in} gin也是 n × m n \times m n×m的矩阵。同时, g i n g_{in} gin起阀门的作用,所以需要通过element-wise的相乘来控制装载与否,我们通常用 ⨂ \bigotimes 来表示这种操作。最后得到的公式:

C ( t ) = C ( t − 1 ) + g i n ⨂ C ~ ( t ) C(t) = C(t - 1) + g_{in} \bigotimes \tilde{C}(t) C(t)=C(t1)+ginC~(t)

不过输入门只在必要的时候开启,所以大部分情况下上述公式可以等价为: C ( t ) = C ( t − 1 ) C(t) = C(t - 1) C(t)=C(t1),这样就可以降低梯度爆炸/消失出现的可能性(因为这就相当于是不添加任何东西直接一直将同样的记忆往后传)。

遗忘门

如果输入门一直打开,就会有大量的信息涌入到记忆中,导致记忆 C C C的值变得非常大。因为输出是 y ( t ) = f ( C ( t ) ) y(t) = f(C(t)) y(t)=f(C(t)),而且通常激活函数选的是sigmoid、tanh这种激活函数,那么就会导致激活函数的输出饱和。例如tanh,在输入值很大的时候梯度基本消失了:

在这里插入图片描述

所以就需要添加个遗忘的机制来将记忆中的信息剔除,这个记忆就是遗忘门。遗忘门其实就是个阀门,所以这里还是用乘法实现:

C ( t ) = g f o r g e t C ( t − 1 ) + g i n ⨂ C ~ ( t ) C(t) = g_{forget}C(t - 1) + g_{in} \bigotimes \tilde{C}(t) C(t)=gforgetC(t1)+ginC~(t)

输出门

在处理事情的时候,通常人只会让其中一部分跟当前人母当前时刻相关的脑细胞输出,所以我们设计的神经网络也一样,要添加个输出门来控制负责输出的脑细胞:

y ( t ) = g o u t ⨂ f ( C ( t ) ) y(t) = g_{out} \bigotimes f(C(t)) y(t)=goutf(C(t))

Peephole(猫眼)

不要看错成“pee hole”了。

回到正题,当输出门因为某种原因关闭的时候就会导致记忆 C ( t ) C(t) C(t)被截断,这样每一个时序中门只受当前时刻的外部输入 x ( t ) x(t) x(t)控制了。为了解决这个问题就是直接把长时记忆单元接入到各个们,即接入 C ( t − 1 ) C(t - 1) C(t1)到遗忘、输入门,把 C ( t ) C(t) C(t)接入到输出门(因为信息流动到输出门时长时记忆已经计算完了,具体可以看最上面的图)。最后我们得到的公式就是:

g i n ( t ) = s i g m o i d ( W ⋅ x ( t ) ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{in}(t) = \mathrm{sigmoid}(W \cdot x(t)) + V \cdot y(t - 1) + U \cdot C(t - 1)) gin(t)=sigmoid(Wx(t))+Vy(t1)+UC(t1))
g f o r g e t ( t ) = s i g m o i d ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{forget}(t) = \mathrm{sigmoid}(W \cdot x(t) + V \cdot y(t - 1) + U \cdot C(t - 1)) gforget(t)=sigmoid(Wx(t)+Vy(t1)+UC(t1))
g o u t ( t ) = s i g m o i d ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t ) ) g_{out}(t) = \mathrm{sigmoid}(W \cdot x(t) + V \cdot y(t - 1) + U \cdot C(t)) gout(t)=sigmoid(Wx(t)+Vy(t1)+UC(t))

这些后来添加的连接叫做“Peephole”。

总结下现在得到的网络:

长时记忆:
C ( t ) = g f o r g e t C ( t − 1 ) + g i n ⨂ C ~ ( t ) C(t) = g_{forget}C(t - 1) + g_{in} \bigotimes \tilde{C}(t) C(t)=gforgetC(t1)+ginC~(t)
C ~ ( t ) = f ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) ) \tilde{C}(t) = f(W \cdot x(t) + V \cdot y(t - 1)) C~(t)=f(Wx(t)+Vy(t1))
y ( t ) = g o u t ⨂ f ( C ( t ) ) y(t) = g_{out} \bigotimes f(C(t)) y(t)=goutf(C(t))

g i n ( t ) = s i g m o i d ( W ⋅ x ( t ) ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{in}(t) = \mathrm{sigmoid}(W \cdot x(t)) + V \cdot y(t - 1) + U \cdot C(t - 1)) gin(t)=sigmoid(Wx(t))+Vy(t1)+UC(t1))
g f o r g e t ( t ) = s i g m o i d ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{forget}(t) = \mathrm{sigmoid}(W \cdot x(t) + V \cdot y(t - 1) + U \cdot C(t - 1)) gforget(t)=sigmoid(Wx(t)+Vy(t1)+UC(t1))
g o u t ( t ) = s i g m o i d ( W ⋅ x ( t ) + V ⋅ y ( t − 1 ) + U ⋅ C ( t ) ) g_{out}(t) = \mathrm{sigmoid}(W \cdot x(t) + V \cdot y(t - 1) + U \cdot C(t)) gout(t)=sigmoid(Wx(t)+Vy(t1)+UC(t))

引入短时记忆/隐藏层

直接用隐藏层单元 h h h代替最终输出 y y y,可得:

C ( t ) = g f o r g e t ⨂ C ( t − 1 ) + g i n ⨂ C ~ ( t ) C(t) = g_{forget} \bigotimes C(t - 1) + g_{in} \bigotimes \tilde{C}(t) C(t)=gforgetC(t1)+ginC~(t)
C ~ ( t ) = f ( W ⋅ x ( t ) + V ⋅ h ( t − 1 ) ) \tilde{C}(t) = f(W \cdot x(t) + V \cdot h(t - 1)) C~(t)=f(Wx(t)+Vh(t1))
h ( t ) = g o u t ⨂ f ( C ( t ) ) h(t) = g_{out} \bigotimes f(C(t)) h(t)=goutf(C(t))
y ( t ) = h ( t ) y(t) = h(t) y(t)=h(t)

g i n ( t ) = s i g m o i d ( W ⋅ x ( t ) ) + V ⋅ h ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{in}(t) = \mathrm{sigmoid}(W \cdot x(t)) + V \cdot h(t - 1) + U \cdot C(t - 1)) gin(t)=sigmoid(Wx(t))+Vh(t1)+UC(t1))
g f o r g e t ( t ) = s i g m o i d ( W ⋅ x ( t ) + V ⋅ h ( t − 1 ) + U ⋅ C ( t − 1 ) ) g_{forget}(t) = \mathrm{sigmoid}(W \cdot x(t) + V \cdot h(t - 1) + U \cdot C(t - 1)) gforget(t)=sigmoid(Wx(t)+Vh(t1)+UC(t1))
g o u t ( t ) = s i g m o i d ( W ⋅ x ( t ) + V ⋅ h ( t − 1 ) + U ⋅ C ( t ) ) g_{out}(t) = \mathrm{sigmoid}(W \cdot x(t) + V \cdot h(t - 1) + U \cdot C(t)) gout(t)=sigmoid(Wx(t)+Vh(t1)+UC(t))

由于h随时都可以被输出门截断,所以我们可以很感性的把h理解为短时记忆单元。
而从数学上看的话,更是短时记忆了,因为梯度流经h的时候,经历的是h(t)->c(t)->h(t-1)的连环相乘的路径(在输入输出门关闭前),显然如前边的数学证明中所述,这样会发生梯度爆炸和消失,而梯度消失的时候就意味着记忆消失了,即h为短时记忆单元。
同样的思路可以再证明一下,由于梯度只从c走的时候,存在一条无连环相乘的路径,可以避免梯度消失。又有遗忘门避免激活函数和梯度饱和,因此c为长时记忆单元。
(引用自Step-by-step to LSTM: 解析LSTM神经网络设计原理

最后,我们就得到了LSTM(长短时记忆神经网络)。

参考

Step-by-step to LSTM: 解析LSTM神经网络设计原理:这一篇必须看,很接地气,本篇博文绝大部分内容借鉴了这篇文章的内容(某种程度上本博客就是这篇文章的笔记)
Understanding LSTM Networks:很多图片来自这篇文章(例如LSTM相关的基本都来自这篇文章)
理解LSTM Networks:上面这篇文章的翻译
难以置信!LSTM和GRU的解析从未如此清晰(动图+视频)


  1. 每个人都在使用LSTM,主流学术圈却只想让它的发明者闭嘴 ↩︎

  2. LSTM神经网络输入输出究竟是怎样的? ↩︎

  3. 此处假设您是将句子拆分成词并且一个词一个词地喂给循环神经网络 ↩︎

  4. 此处例子改自Step-by-step to LSTM: 解析LSTM神经网络设计原理,看不懂这个例子的话看这篇 ↩︎

©️2021 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值