LSTM的公式推导详解

导言

在Alex Graves的这篇论文《Supervised Sequence Labelling with Recurrent Neural Networks》中对LSTM进行了综述性的介绍,并对LSTM的Forward Pass和Backward Pass进行了公式推导。

这篇文章将用更简洁的图示和公式一步步对Forward和Backward进行推导,相信读者看完之后能对LSTM有更深入的理解。

如果读者对LSTM的由来和原理存在困惑,推荐DarkScope的这篇博客:《RNN以及LSTM的介绍和公式梳理》

一、LSTM的基础结构

LSTM的结构中每个时刻的隐层包含了多个memory blocks(一般我们采用一个block),每个block包含了多个memory cell,每个memory cell包含一个Cell和三个gate,一个基础的结构示例如下图: 
image

一个memory cell只能产出一个标量值,一个block能产出一个向量。

二、LSTM的前向传播(Forward Pass)

1. 引入

首先我们在上述LSTM的基础结构之上构造时序结构,这样让读者更清晰地看到Recurrent的结构:

LSTM的整体结构

这里我们有几个约定:

  1. 每个时刻的隐层包含一个block
  2. 每个block包含一个memory cell

下面前向传播我们则从Input开始,逐个求解Input Gate、Forget Gate、Cells Gate、Ouput Gate和最终的Output

这里需要申明的一点,推导过程严格按照上述图示LSTM的结构;论文中对相较于该文章的推导过程会有增加一些项,在每一个公式不一致的地方我都会有相应说明。

2. Input Gate(ιι) 的计算

Input Gate接受两个输入:

  1. 当前时刻的Input作为输入:xtxt
  2. 上一时刻同一block内所有Cell作为输入:st−1csct−1

该案例中每层仅有单个Block、单个cemory cell,可以忽略∑Cc=1∑c=1C,以下Forget Gate和Output Gate做相同处理。

Input Gate

最终Input Gate的输出为:

 

atι=∑i=1Iωiιxti+∑c=1Cωcιst−1caιt=∑i=1Iωiιxit+∑c=1Cωcιsct−1

 

 

btι=f(atι)bιt=f(aιt)

 

这里Input Gate还可以接受上一个时刻中不同block的输出bt−1hbht−1作为输入,论文中atιaιt会增加一项∑Hh=1ωhιbt−1h∑h=1Hωhιbht−1。

3. Forget Gate(ϕϕ) 的计算

Forget Gate接受两个输入:

  1. 当前时刻的Input作为输入:xtxt
  2. 上一时刻同一block内所有Cell作为输入:st−1csct−1

Forget Gate

最终Forget Gate的输出为:

 

atϕ=∑i=1Iωiϕxti+∑c=1Cωcϕst−1caϕt=∑i=1Iωiϕxit+∑c=1Cωcϕsct−1

 

 

btϕ=f(atϕ)bϕt=f(aϕt)

 

这里Input Gate还可以接受上一个时刻中不同block的输出bt−1hbht−1作为输入,论文中atϕaϕt会增加一项∑Hh=1ωhϕbt−1h∑h=1Hωhϕbht−1。

4. Cell(cc) 的计算

Cell的计算稍有些复杂,接受两个输入:

  1. Input Gate和Input输入的乘积
  2. Forget Gate和上一时刻对应Cell输出的乘积

Cell

最终Cell的输出为:

 

atc=∑i=1Iωicxtiact=∑i=1Iωicxit

 

 

stc=btϕst−1c+btιg(atc)sct=bϕtsct−1+bιtg(act)

 

这里Input Gate还可以接受上一个时刻中不同block的输出bt−1hbht−1作为输入,论文中atcact会增加一项∑Hh=1ωhcbt−1h∑h=1Hωhcbht−1。

5. Output Gate(ωω) 的计算

Output Gate接受两个输入:

  1. 当前时刻的Input作为输入:xtxt
  2. 当前时刻同一block内所有Cell作为输入:stcsct

这里Output Gate接受“当前时刻Cell的输出”而不是“上一时刻Cell的输出”,是由于此时Cell的结果已经产出,我们控制Output Gate的输出直接采用Cell当前的结果就行了,无须使用上一时刻。

Output Gate

最终Output Gate的输出为:

 

atω=∑i=1Iωiωxti+∑c=1Cωcωstcaωt=∑i=1Iωiωxit+∑c=1Cωcωsct

 

 

btω=f(atω)bωt=f(aωt)

 

这里Cell还可以接受上一个时刻中其他gate链接过来的边,论文中atϕaϕt会增加一项∑Hh=1ωhϕbt−1h∑h=1Hωhϕbht−1,这里HH是泛指t-1时刻的Cell或三个Gate。

6. Cell Output(cc) 的计算

Cell Output的计算即将Output Gate和Cell做乘积即可。

Cell Output

最终Cell Output为:

 

btc=btωh(stc)bct=bωth(sct)

 

7. 小结

至此,整个Block从Input到Output整个Forward Pass已经结束,其中涉及三个Gate和中间Cell的计算,需要注意的是三个Gate使用的激活函数是ff,而Input的激活函数是gg、Cell输出的激活函数是hh。

这里读者需要注意,在整个计算过程中,当前时刻的三个Gate均可以从上一时刻的任意Gate中接受输入,在公式中存在体现,但是在图示中并未画出相应的边。我们可以认为只有上一时刻的Cell才和当前时刻的Cell或三个Gate相连。 
前向小结

三、LSTM的反向传播(Backward Pass)

1. 引入

此处在论文中使用“Backward Pass”一词,但其实即Back Propagation过程,利用链式求导求解整个LSTM中每个权重的梯度。

2. 损失函数的选择

为了通用起见,在此我们仅展示多分类问题的损失函数的选择,对于网络的最终输出我们利用softmaxsoftmax方程计算结果属于某一类的概率(此时结果属于k个类别的概率和为1)。

 

p(Ck|x)=yk=eak∑Kk′=1eak′p(Ck|x)=yk=eka∑k′=1Kek′a

 

注意,ykyk对akak的偏导为∂yk′∂ak=ykδkk′−ykyk′∂yk′∂ak=ykδkk′−ykyk′(δkk′δkk′当k==k′k==k′时为1,其他为0)

其中,对于网络输出a1,a2,...a1,a2,...对应我们可以得到p(C1|x),p(C2|x),...p(C1|x),p(C2|x),...,即给定输入xx输出类别为C1,C2,...C1,C2,...的概率。

这样损失函数(Loss Function)就很好定义了:对于k∈1,2,...,Kk∈1,2,...,K,网络输出的类别为k概率为ykyk,而真实值zkzk:

 

L(x,z)=−lnp(z|x)=−∑k=1KzklnykL(x,z)=−lnp(z|x)=−∑k=1Kzklnyk

 

3. 权重的更新

对于神经网络中的每一个权重,我们都需要找到对应的梯度,从而通过不断地用训练样本进行随机梯度下降找到全局最优解,那么首先我们需要知道哪些权重需要更新。

一般层次分明的神经网络有input层、hidden层和output层,层与层之间的权重比较直观;但在LSTM中通过公式才能找到对应的权重,和图示中的边并不是一一对应,下面我将LSTM的单个Block中需要更新的权重在图示上标示了出来:

权重

为了方便起见,这里需要申明的是:我们仅考虑上一时刻的Cell仅和当前时刻的Cell和三个Gate相连。

2. Cell Output的梯度

首先我们计算每一个输出类别的梯度: 

δtk========∂L(x,z)∂atk∂(−∑Kk′=1zk′lnyk′)atk−∑k′=1Kzk′∂lnyk′∂atk−∑k′=1Kzk′yk′∂yk′∂atk−∑k′=1Kzk′yk′(ykδkk′−ykyk′)−∑k′=1Kzk′yk′ykδkk′+∑k′=1Kzk′yk′ykyk′−zk+yk∑k′=1Kzk′yk−zkδkt=∂L(x,z)∂akt=∂(−∑k′=1Kzk′lnyk′)akt=−∑k′=1Kzk′∂lnyk′∂akt=−∑k′=1Kzk′yk′∂yk′∂akt=−∑k′=1Kzk′yk′(ykδkk′−ykyk′)=−∑k′=1Kzk′yk′ykδkk′+∑k′=1Kzk′yk′ykyk′=−zk+yk∑k′=1Kzk′=yk−zk

 

也即每一个输出类别的梯度仅和其预测值和真实值相关,这样对于Cell Output的梯度则可以通过链式求导法则推导出来:

 

ϵtc=∂L(x,z)∂btc=∑k=1K∂L(x,z)∂atk∂atk∂btc=∑k=1Kδtkωckϵct=∂L(x,z)∂bct=∑k=1K∂L(x,z)∂akt∂akt∂bct=∑k=1Kδktωck

 

由于Output还可以连接下一个时刻的一个Cell、三个Gate,那么下一个时刻的一个Cell、三个Gate的梯度则可以传递回当前时刻Output,所以在论文中存在额外项∑Gg=1ωcgδt+1g∑g=1Gωcgδgt+1,为简便起见,公式和图示中未包含。

Cell Output

3. Output Gate的梯度

根据链式求导法则,Output Gate的梯度可以由以下公式推导出来:

 

δtω=∂L(x,z)∂atω=∂L(x,z)∂btc∂btc∂btω∂btω∂atω=ϵtch(stc)f′(atw)δωt=∂L(x,z)∂aωt=∂L(x,z)∂bct∂bct∂bωt∂bωt∂aωt=ϵcth(sct)f′(awt)

 

另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Output Gate的梯度写成了f′(atw)∑Cc=1ϵtch(stc)f′(awt)∑c=1Cϵcth(sct),但推导过程一致。推导过程见下图,说明梯度汇总到单个Gate中:

Output Gate

4. Cell的梯度

细心的读者在这里会发现,Cell的计算结构和普遍的神经网络不太一样,让我们首先来回顾一下Cell部分的Forward计算过程:

 

atc=∑i=1Iωicxtiact=∑i=1Iωicxit

 

 

stc=btϕst−1c+btιg(atc)sct=bϕtsct−1+bιtg(act)

 

输入数据贡献给atcact,而Cell同时能够接受Input Gate和Forget Gate的输入。

这样梯度就直接从Cell向下传递:

 

δtc=∂L(x,z)∂atc=∂L(x,z)∂stc∂stc∂atc=∂L(x,z)∂stcbtιg′(atc)δct=∂L(x,z)∂act=∂L(x,z)∂sct∂sct∂act=∂L(x,z)∂sctbιtg′(act)

 

在这里,我们定义States,由于Cell的梯度可以由以下几个计算单元传递回来:

  1. 当前时刻的Cell Output
  2. 下一个时刻的Cell
  3. 下一个时刻的Input Gate
  4. 下一个时刻的Output Gate

那么States可以这样求解,上面1~4个能够回传梯度的计算单元和下面公式中一一对应: 

ϵts====∂L(x,z)∂stc∂Lt(x,z)∂stc+∂Lt+1(x,z)∂st+1c∂st+1c∂stc+∂Lt+1(x,z)∂at+1ι∂at+1ι∂stc+∂Lt+1(x,z)∂at+1ϕ∂at+1ϕ∂stc(∂L(x,z)∂atw∂atw∂stc+∂L(x,z)∂btc∂btc∂stc)+bt+1ϕϵt+1s+ωcιδt+1ι+ωcϕδt+1ϕδtωωcω+ϵtcbtωh′(stc)+bt+1ϕϵt+1s+ωcιδt+1ι+ωcϕδt+1ϕϵst=∂L(x,z)∂sct=∂Lt(x,z)∂sct+∂Lt+1(x,z)∂sct+1∂sct+1∂sct+∂Lt+1(x,z)∂aιt+1∂aιt+1∂sct+∂Lt+1(x,z)∂aϕt+1∂aϕt+1∂sct=(∂L(x,z)∂awt∂awt∂sct+∂L(x,z)∂bct∂bct∂sct)+bϕt+1ϵst+1+ωcιδιt+1+ωcϕδϕt+1=δωtωcω+ϵctbωth′(sct)+bϕt+1ϵst+1+ωcιδιt+1+ωcϕδϕt+1

 

那么: 

δtc=ϵtsbtιg′(atc)δct=ϵstbιtg′(act)

 

Cell

细心的读者会发现,论文中∂L(x,z)∂btc∂L(x,z)∂bct并没有求和,这里作者持保留态度,应该存在求和项。

同时由于Cell可以连接到下一个时刻的Forget Gate、Output Gate和Input Gate,那么下一时刻的这三个Gate则可以将梯度传播回来,所以在论文中我们会发现ϵtsϵst拥有这三项:bt+1ϕϵt+1sbϕt+1ϵst+1、ωclδt+1ιωclδιt+1和ωcϕδt+1ϕωcϕδϕt+1。

5. Forget Gate的梯度

Forget Gate的梯度计算就比较简单明了:

 

δtϕ=∂L(x,z)∂atϕ=∂L(x,z)∂stc∂stc∂btϕ∂btϕ∂atϕ=ϵtsst−1cf′(atϕ)δϕt=∂L(x,z)∂aϕt=∂L(x,z)∂sct∂sct∂bϕt∂bϕt∂aϕt=ϵstsct−1f′(aϕt)

 

Forget Gate

另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Forget Gate的梯度写成了f′(atϕ)∑Cc=1st−1cϵtsf′(aϕt)∑c=1Csct−1ϵst,但推导过程一致,说明梯度汇总到单个Gate中。

6. Input Gate的梯度

Input Gate的梯度计算如下:

 

δtι=∂L(x,z)∂atι=∂L(x,z)∂stc∂stc∂btι∂btι∂atι=ϵtsg(atc)f′(atι)διt=∂L(x,z)∂aιt=∂L(x,z)∂sct∂sct∂bιt∂bιt∂aιt=ϵstg(act)f′(aιt)

 

Input Gate

另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Input Gate的梯度写成了f′(atι)∑Cc=1g(atc)ϵtsf′(aιt)∑c=1Cg(act)ϵst,但推导过程一致,说明梯度汇总到单个Gate中。

7. 小结

至此,所有的梯度求解已经结束,同样我们将这个Backward Pass的所有公式列出来:

小结

剩下的事情即利用梯度去更新每个权重: 

Δωn=mΔωn−1−α∂L∂ωnΔωn=mΔωn−1−α∂L∂ωn

 

其中mΔωn−1mΔωn−1为上一次权重的更新值,且m∈[0,1]m∈[0,1];而∂L∂ωn∂L∂ωn即上面我们求到的每一个梯度。

例如每次更新ωiϕωiϕ的ΔΔ量即: 

Δωniϕ=mΔωn−1iϕ−αxiδtϕΔωiϕn=mΔωiϕn−1−αxiδϕt

 

其中δtϕδϕt即Forget Gate的梯度。

三、总结

以上就是LSTM中的前向和反向传播的公式推导,在这里作者仅以最简单的单个Cell的场景进行示例。

在实际工程实践中,常常会涉及到同一时刻多个Cell且互相之间的Gate存在连接,同时上一个时刻或下一个时刻的Cell和三个Gate之间同样存在复杂的连接关系。

但如果读者能够明晰上述的推导过程,那么无论多复杂都能够迎刃而解了。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值