LSTM公式详解&推导

本文深入解析LSTM(长短时记忆网络),旨在解决RNN中的梯度消失问题。文章介绍了LSTM的结构、流程,详细推导了前向传播和反向传播的公式,包括输入门、遗忘门、输出门以及Cell状态的计算,为读者提供了深入理解LSTM的基础。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

书籍简介

《Surpervised Sequence Labelling with Recurrent Neural Network》(《用循环神经网络进行序列标记》),RNN(Recurrent Neural Network,循环神经网络)经典教材,由多伦多大学Alexander Graves所著,详细叙述了各种RNN模型及其推导。本文介绍该书的LSTM部分。对于该书,想深入了解的朋友点这里获取资源。
1

LSTM理解

  LSTM(Long Short-Term Memory Networks,长短时记忆网络),由Hochreiter和Schmidhuber于1997年提出,目的是解决一般循环神经网络中存在的梯度爆炸(输入信息激活后权重过小)及梯度消失(例如sigmoid、tanh的激活值在输入很大时其梯度趋于零)问题,主要通过引入门和Cell状态的概念来实现梯度的调整,已被大量应用于时间序列预测等深度学习领域。
  下面的描述主要侧重公式推导,对LSTM来由更详细的讨论请见《Step-by-step to LSTM: 解析LSTM神经网络设计原理》。

LSTM流程简介

  LSTM采用了门控输出的方式,即三门(输入门、遗忘门、输出门)两态(Cell State长时、Hidden State短时)。其核心即Cell State,指用于信息传播的Cell的状态,在结构示意图(图1,图源Understanding LSTMs,略改动)中是最上面的直链(从 C t − 1 C_{t-1} Ct1 C t C_t Ct)。


1

图1

  Memory Cell 接受两个输入,即上一时刻的输出值 h t − 1 h_{t-1} ht1和本时刻的输入值 x t x_t xt,由这两个参数 先进入遗忘门,得到决定要舍弃的信息 f t f_t ft(即权重较小的信息)后,再进入输入门,得到决定要更新的信息 i t i_t it(即与上一Cell相比权重较大的信息)以及当前时刻的Cell状态 C ~ t \tilde{C}_t C~t(候选向量,可理解为中间变量,存储当前 Cell State 信息),最后由这两个门(遗忘门,输入门)的输出值(即 f t , i t , C t ~ f_t,i_t,\tilde{C_t} ft,it,Ct~)进行组合(上一Cell状态 C t − 1 × C^{t-1}\times Ct1×要遗忘信息的激活值 f t f_t ft 与 当前时刻Cell状态 C t ~ × \tilde{C_t}\times Ct~×需要记忆信息的激活值 i t i_t it进行叠加,从图中可以更直观得到),得到分别的长时( C t C_t Ct)和短时( h t h_t ht)信息,最后进行存储操作及对下一个神经元的输入。下图2介绍了LSTM在网络中是如何工作的。


2

图2


根据图1,可依次得到三个门的形式方程如下(符号与图中保持一致):

  1. 遗忘门:

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma\left(W_f\cdot[h_{t-1}, x_t]+b_f\right) ft=σ(Wf[ht1,xt]+bf)

  1. 输入门:

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t=\sigma\left(W_i\cdot[h_{t-1}, x_t]+b_i\right) it=σ(Wi[ht1,xt]+bi)

C t ~ = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C_t}=\tanh\left(W_C\cdot[h_{t-1}, x_t]+b_C\right) Ct~=tanh(WC[ht1,xt]+bC)

以及 t t t时刻的Cell 状态(长时)方程:

C t = f t ⋅ C t − 1 + i t ⋅ C t ~ C_t=f_t\cdot C_{t-1}+i_t\cdot \tilde{C_t} Ct=ftCt1+itCt~

  1. 输出门:

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t=\sigma\left(W_o\cdot[h_{t-1}, x_t]+b_o\right) ot=σ(Wo[ht1,xt]+bo)

h t = o t ⋅ tanh ⁡ ( C t ) h_t=o_t\cdot\tanh{(C_t)} ht=ottanh(Ct)

算法及公式

  根据上面的描述及图1,首先定义如下符号(符号为方便理解,与书中保持一致):

一些函数

  • f f f的激活函数
  • g g gCell输入的激活函数
  • h h hCell输出的激活函数

  • L \mathcal{L} L : 训练模型时的损失函数
  • σ ( z ) \sigma(z) σ(z):Sigmoid激活函数
    σ ( z ) = 1 1 + e − z = 1 + tanh ⁡ ( z / 2 ) 2 , \sigma(z)=\frac{1}{1+\mathrm{e}^{-z}}=\frac{1+\tanh(z/2)}{2}, σ(z)=1+ez1=21+tanh(z/2),

σ ′ ( z ) = σ ( z ) [ 1 − σ ( z ) ] . \sigma'(z)=\sigma(z)[1-\sigma(z)]. σ(z)=σ(z)[1σ(z)].

  • tanh ⁡ ( z ) \tanh(z) tanh(z):tanh激活函数
    tanh ⁡ ( z ) = e z − e − z e z + e − z , \tanh(z)=\frac{\mathrm{e}^z-\mathrm{e}^{-z}}{\mathrm{e}^z+\mathrm{e}^{-z}}, tanh(z)=ez+ezezez,

tanh ⁡ ′ ( z ) = 1 − tanh ⁡ 2 ( z ) . \tanh'(z)=1-\tanh^2(z). tanh(z)=1tanh2(z).

一些符号

  • I I I输入层 信息的数量
  • K K K输出层 信息的数量
  • H H H隐层 Cell状态的数量(注意这里的Cell与下面的Cell不同,代表短时记忆Cell),指图1中最下面的一条直链,即从 h t − 1 h_{t-1} ht1 h t h_t ht,处理短时记忆
  • C C CCell状态信息(长时记忆状态)的数量
  • T T T :总时间数(网络层总数),即 t = 0 , 1 , 2 , ⋯   , T t=0,1,2,\cdots,T t=0,1,2,,T

  • ϕ \phi ϕ :下标,指一个LSTM单元的遗忘门
  • ι \iota ι :下标,指一个LSTM单元的输入门
  • ω \omega ω :下标,指一个LSTM单元的输出门
  • c c c :下标,指神经元中某一个 C C C 记忆元胞(Cell)

  • w i j w_{ij} wij :从单元 i i i到单元 j j j权重
  • b j t b_j^t bjt t t t时刻第 j j j个单元的激活值,在 t = 0 t=0 t=0时初始化为 0 0 0
  • a j t a_j^t ajt t t t时刻第 j j j个单元的带权输入,可作抽象定义如下

a j t = ∑ i w i j b i t − 1 . a_j^t=\sum_{i}{w_{ij}b_{i}^{t-1}}. ajt=iwijbit1.

  • s c t s_c^t sct t t t时刻记忆元胞 c c c状态(State),在 t = 0 t=0 t=0时初始化为 0 0 0

  • δ j t \delta_j^t δjt t t t时刻第 j j j个单元的误差,在 t = T + 1 t=T+1 t=T+1时初始化为 0 0 0。一般化的定义为

δ j t = ∂ L ∂ a j t . \delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t}. δjt=ajtL.

前向传播

由上述的形式方程,很容易得到下面的前向传播公式:

  1. 遗忘门。由图1可知,遗忘门的输出依赖三个变量(图1中表示为左下角的两个输入和左上角的一个输入),分别是:上一时刻 ( t − 1 ) (t-1) (t1)神经元的短时记忆输出 h t − 1 h_{t-1} ht1,本时刻 ( t ) (t) (t)神经元的输入 x t x_t xt以及上一时刻 ( t − 1 ) (t-1) (t1)神经元的长时记忆输出Cell状态 s c t − 1 s_c^{t-1} sct1,乘以权重因子后对层数求和即可得到遗忘门的输入值及激活值如下:

a ϕ t = ∑ i = 1 I w i ϕ x i t + ∑ h = 1 H w h ϕ b h t − 1 + ∑ c = 1 C w c ϕ s c t − 1 (1.1) a_\phi^t=\sum_{i=1}^Iw_{i\phi}x_i^t+\sum_{h=1}^{H}w_{h\phi}b_h^{t-1}+\sum_{c=1}^Cw_{c\phi}s_c^{t-1}\tag{1.1} aϕt=i=1Iwiϕxit+h=1Hwhϕbht1+c=1Cwcϕsct1(1.1)

b ϕ t = f ( a ϕ t ) (1.2) b_\phi^t=f(a_\phi^t)\tag{1.2} bϕt=f(aϕt)(1.2)

  1. 输入门。其输出所依赖的变量与遗忘门相同,故同理可得

a ι t = ∑ i = 1 I w i ι x i t + ∑ h = 1 H w h ι b h t − 1 + ∑ c = 1 C w c ι s c t − 1 (2.1) a_\iota^t=\sum_{i=1}^Iw_{i\iota}x_i^t+\sum_{h=1}^{H}w_{h\iota}b_h^{t-1}+\sum_{c=1}^Cw_{c\iota}s_c^{t-1}\tag{2.1} aιt=i=1Iwiιxit+h=1Hwhιbht1+c=1Cwsct1(2.1)

b ι t = f ( a ι t ) (2.2) b_\iota^t=f(a_\iota^t)\tag{2.2} bιt=f(aιt)(2.2)

  1. Cell状态。由输入门的 t t t时刻的Cell 状态(长时)方程立即可得。

a c t = ∑ i = 1 I w i c x i t + ∑ h = 1 H w h c b h t − 1 (3.1) a_c^t =\sum_{i=1}^I w_{ic}x_i^t+\sum_{h=1}^H w_{hc}b_h^{t-1}\tag{3.1} act=i=1Iwicxit+h=1Hwhcbht1(3.1)

一一对应形式方程即可得到 s c t s_c^t sct表达式如下

C t = f t ⋅ C t − 1 + i t ⋅ C t ~ ⋮    ⋮         ⋮    ⋮       ⋮ s c t = b ϕ t ⋅ s c t − 1   + b ι t ⋅ g ( a c t ) (3.2) \begin{aligned} C_t&=f_t\cdot C_{t-1}+i_t\cdot \tilde{C_t} \\ \vdots& \quad\ \ \vdots\ \ \ \ \ \ \ \vdots \qquad \ \ \vdots\ \ \ \ \ \vdots\\ s_c^{t}&= b_\phi^t \cdot s_c^{t-1}\,+b_\iota^t \cdot g(a_c^t)\tag{3.2} \end{aligned} Ctsct=ftCt1+itCt~                =bϕtsct1+bιtg(act)(3.2)

  1. 输出门。由遗忘门同理可得

a ω t = ∑ i = 1 I w i ω x i t + ∑ h = 1 H w h ω b h t − 1 + ∑ c = 1 C w c ω s c t − 1 (4.1) a_\omega^t=\sum_{i=1}^Iw_{i\omega}x_i^t+\sum_{h=1}^{H}w_{h\omega}b_h^{t-1}+\sum_{c=1}^Cw_{c\omega}s_c^{t-1}\tag{4.1} aωt=i=1Iwxit+h=1Hwbht1+c=1Cwcωsct1(4.1)

b ω t = f ( a ω t ) (4.2) b_\omega^t=f(a_\omega^t)\tag{4.2} bωt=f(aωt)(4.2)

  1. Cell输出。指激活后的Cell状态(短时记忆),同理可由形式方程一一对应得到,即

h t = o t   ⋅   tanh ⁡ ( C t ) ⋮     ⋮   ⋮ b c t = b ω t ⋅    h ( s c t ) (5.1) \begin{aligned}h_t&=o_t\ \cdot\ \tanh{(C_t)}\tag{5.1} \\ \vdots&\ \ \ \quad\vdots\ \qquad \vdots\\ b_c^t&=b_\omega^t \cdot \ \ h(s_c^t)\end{aligned} htbct=ot  tanh(Ct)    =bωt  h(sct)(5.1)

反向传播

  重头戏来了!建议不熟悉反向传播的朋友看一下我的另一篇文章nndl学习笔记(二)反向传播公式推导,帮助你快速理解&回顾反向传播。

  同样地,为了与前向传播对应,这里也采用五个部分进行证明。反向传播,其目的就是通过计算损失函数关于权重和偏置的偏导数(本例中不对偏置进行分析),从而得到每一个神经元上出现的误差(误差定义为损失函数对神经元输入的偏导数),最后均摊给每个神经元,以此逐步减小误差。因为需要反向传播,所以顺序与前向传播正好相反(从后往前计算)。

关于误差的定义

  • Cell 输出的误差(短时记忆) ϵ c t = ∂ L ∂ b c t \epsilon_c^t=\frac{\partial \mathcal{L}}{\partial b_c^t} ϵct=bctL
  • Cell 状态的误差(长时记忆) ϵ s t = ∂ L ∂ s c t \epsilon_s^t=\frac{\partial \mathcal{L}}{\partial s_c^t} ϵst=sctL
  • δ j t \delta_j^t δjt t t t时刻第 j j j个单元的误差,在 t = T + 1 t=T+1 t=T+1时初始化为 0 0 0。定义为

δ j t = ∂ L ∂ a j t \delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t} δjt=ajtL

公式推导

  这些公式的核心,都是根据链式法则求偏导数,需要注意损失函数与哪些变量有关,找准变量,再应用求导法则,即可轻松计算出表达式。

  1. Cell输出(短时记忆)。
    首先找Cell输出与哪些量有关,从图1可以得知其只与隐层(Cell短时记忆状态)和输出层两个部分的信息有关,再根据误差定义 δ j t = ∂ L ∂ a j t \delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t} δjt=ajtL,可以得到:
    ϵ c t = ∂ L ∂ b c t = ∂ L ∂ a j t ∂ a j t ∂ b c t = ∑ h = 1 H ∂ L ∂ a h t + 1 ∂ a h t + 1 ∂ b c t + ∑ k = 1 K ∂ L ∂ a k t ∂ a k t ∂ b c t = ∑ h = 1 H δ h t + 1 ∂ a h t + 1 ∂ b c t + ∑ k = 1 K δ k t ∂ a k t ∂ b c t \begin{aligned} \epsilon_c^t &=\frac{\partial \mathcal{L}}{\partial b_c^t} =\frac{\partial \mathcal{L}}{\partial a_j^t} \frac{\partial a_j^t}{\partial b_c^t} \\ &= \sum_{h=1}^H\frac{\partial \mathcal{L}}{\partial a_h^{t+1}} \frac{\partial a_h^{t+1}}{\partial b_c^t}+\sum_{k=1}^K\frac{\partial \mathcal{L}}{\partial a_k^t} \frac{\partial a_k^t}{\partial b_c^t} \\ &=\sum_{h=1}^H\delta_h^{t+1} \frac{\partial a_h^{t+1}}{\partial b_c^{t}} + \sum_{k=1}^K\delta_k^t \frac{\partial a_k^t}{\partial b_c^t} \end{aligned} ϵct=bctL=ajtLbctajt=h=1Haht+1Lbctaht+1+k=1KaktLbctakt=h=1Hδht+1bctaht+1+k=1Kδktbctakt
    注意到这里 H H H层时间状态取 t + 1 t+1 t+1 K K K层取 t t t,是为了与前向传播式子的意义保持一致,即:隐层Cell状态前向传播需要前一时刻 ( t − 1 ) (t-1) (t1)的隐层Cell状态,而输出只需与本时刻输入的时刻 ( t ) (t) (t)一致即可,而反向传播正好相反(具体可见图1)。
    再根据带权输入的一般定义(同上,需要根据情况构造定义式,即: H H H层时刻变化而 K K K层时刻保持不变)
    a j t = ∑ i w i j b i t − 1 a_j^t=\sum_{i}{w_{ij}b_{i}^{t-1}} ajt=iwijbit1
    代入得到(注意这里有一步化简,去掉求和号,具体原因可见nndl学习笔记(二)反向传播公式推导公式一的推导部分):
    ϵ c t = ∑ h = 1 H δ h t + 1 ∂ ( w c h b c t ) ∂ b c t + ∑ k = 1 K δ k t ∂ ( w c k b c t ) ∂ b c t = ∑ h = 1 H δ h t + 1 w c h + ∑ k = 1 K δ k t w c k \begin{aligned} \epsilon_c^t&=\sum_{h=1}^H\delta_h^{t+1} \frac{\partial (w_{ch}b_c^{t})}{\partial b_c^{t}}+\sum_{k=1}^K\delta_k^t \frac{\partial (w_{ck}b_c^{t})}{\partial b_c^{t}} \\ &=\sum_{h=1}^H\delta_h^{t+1}w_{ch}+\sum_{k=1}^K\delta_k^tw_{ck} \end{aligned} ϵct=h=1Hδht+1bct(wchbct)+k=1Kδktbct(wckbct)=h=1Hδht+1wch+k=1Kδktwck

  2. 输出门。
    这里只需用到误差定义式 ϵ c t = ∂ L ∂ b c t \epsilon_c^t=\frac{\partial \mathcal{L}}{\partial b_c^t} ϵct=bctL及前向传播的 ( 5.1 ) (5.1) (5.1),最后一步求和是指针对所有神经元输出门激活值误差的叠加。
    δ ω t = ∂ L ∂ a ω t = ∂ L ∂ b ω t ∂ b ω t ∂ a ω t = ∂ L ∂ b ω t f ′ ( a ω t ) = f ′ ( a ω t ) ∂ L ∂ b c t ∂ b c t ∂ b ω t = f ′ ( a ω t ) ϵ c t ∂ b c t ∂ b ω t = f ′ ( a ω t ) ϵ c t ∂ [ b ω t h ( s c t ) ] ∂ b ω t = f ′ ( a ω t ) ∑ c = 1 C h ( s c t ) ϵ c t \begin{aligned} \delta_\omega^t&=\frac{\partial \mathcal{L}}{\partial a_\omega^t} =\frac{\partial \mathcal{L}}{\partial b_\omega^t}\frac{\partial b_\omega^t}{\partial a_\omega^t} \\ &=\frac{\partial \mathcal{L}}{\partial b_\omega^t} f'(a_\omega^t) \\ &=f'(a_\omega^t) \frac{\partial \mathcal{L}}{\partial b_c^t} \frac{\partial b_c^t}{\partial b_\omega^t}\\ &= f'(a_\omega^t) \epsilon_c^t \frac{\partial b_c^t}{\partial b_\omega^t} \\ &= f'(a_\omega^t) \epsilon_c^t \frac{\partial \left[b_\omega^t h(s_c^t)\right]}{\partial b_\omega^t} \\ &=f'(a_\omega^t)\sum_{c=1}^Ch(s_c^t)\epsilon_c^t \end{aligned} δωt=aωtL=bωtLaωtbωt=bωtLf(aωt)=f(aωt)bctLbωtbct=f(aωt)ϵctbωtbct=f(aωt)ϵctbωt[bωth(sct)]=f(aωt)c=1Ch(sct)ϵct

  3. Cell状态(长时记忆)。最长的一个式子,但是把握好变量之间的关系就可以轻松得出( 直接寻找前向传播众多公式中哪个含有变量 s c t s_c^t sct,这样再进行链式法则处理,会更加直观,由于五个式子都含有 s c t s_c^t sct,故下面第四个等号后的式子有五项)。
      推导过程与Cell输出(短时记忆)部分类似,要用到误差的一般定义 δ j t = ∂ L ∂ a j t \delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t} δjt=ajtL,并注意到本时刻Cell状态(长时记忆)是由上一时刻遗忘门 ( ϕ ) (\phi) (ϕ)和输入门 ( ι ) (\iota) (ι)的输出共同决定的(反映在图上就是图1中上面直链的加号);在反向传播中,除了需要将Cell状态(长时记忆)的时间取反 ( s c t + 1 ) (s_c^{t+1}) (sct+1),还要考虑三个门误差的积累(第二个等号后式子第一项),注意这里计算输出门误差时没有取后一时刻 t + 1 t+1 t+1,是因为遗忘门和输入门的误差在前向传播时会传递给下一时刻的带权输入,故反向传播需要后一时刻来计算误差;而输出门误差在本时刻即可计算。反映到方程上为第二个等号后的方程。
    ϵ s t = ∂ L ∂ s c t = ∂ L ∂ a j t + 1 ∂ a j t + 1 ∂ s c t + ∂ L ∂ b c t ∂ b c t ∂ s c t + ∂ L ∂ s c t + 1 ∂ s c t + 1 ∂ s c t = δ j t + 1 ∂ a j t + 1 ∂ s c t + ϵ c t ∂ [ b ω t h ( s c t ) ] ∂ s c t + ϵ s t + 1 ∂ [ b ϕ t + 1 ⋅ s c t   + b ι t + 1 ⋅ g ( a c t + 1 ) ] ∂ s c t = δ ϕ t + 1 ∂ a ϕ t + 1 ∂ s c t + δ ι t + 1 ∂ a ι t + 1 ∂ s c t + δ ω t ∂ a ω t + 1 ∂ s c t + ϵ c t b ω t h ′ ( s c t ) + ϵ s t + 1 b ϕ t + 1 = δ ϕ t + 1 ∂ ( ∑ i = 1 I w i ϕ x i t + 1 + ∑ h = 1 H w h ϕ b h t + ∑ c = 1 C w c ϕ s c t ) ∂ s c t + δ ι t + 1 ∂ ( ∑ i = 1 I w i ι x i t + 1 + ∑ h = 1 H w h ι b h t + ∑ c = 1 C w c ι s c t ) ∂ s c t + δ ω t ∂ ( ∑ i = 1 I w i ω x i t + 1 + ∑ h = 1 H w h ω b h t + ∑ c = 1 C w c ω s c t ) ∂ s c t + ϵ c t b ω t h ′ ( s c t ) + ϵ s t + 1 b ϕ t + 1 = ϵ c t b ω t h ′ ( s c t ) + ϵ s t + 1 b ϕ t + 1 + δ ϕ t + 1 w c ϕ + δ ι t + 1 w c ι + δ ω t w c ω \begin{aligned} \epsilon_s^t &=\frac{\partial \mathcal{L}}{\partial s_c^t} \\ &=\frac{\partial \mathcal{L}}{\partial a_j^{t+1}} \frac{\partial a_j^{t+1}}{\partial s_c^t} + \frac{\partial \mathcal{L}}{\partial b_c^t} \frac{\partial b_c^t}{\partial s_c^t} + \frac{\partial \mathcal{L}}{\partial s_c^{t+1}} \frac{\partial s_c^{t+1}}{\partial s_c^t} \\ &= \delta_j^{t+1} \frac{\partial a_j^{t+1}}{\partial s_c^t} + \epsilon_c^t\frac{\partial \left[b_\omega^t h(s_c^t) \right]}{\partial s_c^t} + \epsilon_s^{t+1} \frac{\partial \left[ b_\phi^{t+1} \cdot s_c^{t}\,+b_\iota^{t+1} \cdot g(a_c^{t+1}) \right]}{\partial s_c^t} \\ &= \delta_\phi^{t+1} \frac{\partial a_\phi^{t+1}}{\partial s_c^t} + \delta_\iota^{t+1} \frac{\partial a_\iota^{t+1}}{\partial s_c^t} + \delta_\omega^t \frac{\partial a_\omega^{t+1}}{\partial s_c^t} + \epsilon_c^t b_\omega^t h'(s_c^t) + \epsilon_s^{t+1}b_{\phi}^{t+1} \\ &= \delta_\phi^{t+1} \frac{\partial \left( \sum_{i=1}^Iw_{i\phi}x_i^{t+1}+\sum_{h=1}^{H}w_{h\phi}b_h^{t}+\sum_{c=1}^Cw_{c\phi}s_c^{t} \right)}{\partial s_c^t} \\ &+ \delta_\iota^{t+1} \frac{\partial \left( \sum_{i=1}^Iw_{i\iota}x_i^{t+1}+\sum_{h=1}^{H}w_{h\iota}b_h^{t}+\sum_{c=1}^Cw_{c\iota}s_c^{t} \right)}{\partial s_c^t} \\ &+ \delta_\omega^t \frac{\partial \left( \sum_{i=1}^Iw_{i\omega}x_i^{t+1}+\sum_{h=1}^{H}w_{h\omega}b_h^{t}+ \sum_{c=1}^Cw_{c\omega}s_c^{t} \right)}{\partial s_c^t} \\ &+ \epsilon_c^t b_\omega^t h'(s_c^t) + \epsilon_s^{t+1}b_{\phi}^{t+1} \\ &= \epsilon_c^t b_\omega^t h'(s_c^t) + \epsilon_s^{t+1}b_{\phi}^{t+1} + \delta_\phi^{t+1}w_{c\phi} + \delta_\iota^{t+1} w_{c\iota} + \delta_\omega^t w_{c\omega} \\ \end{aligned} ϵst=sctL=ajt+1Lsctajt+1+bctLsctbct+sct+1Lsctsct+1=δjt+1sctajt+1+ϵctsct[bωth(sct)]+ϵst+1sct[bϕt+1sct+bιt+1g(act+1)]=δϕt+1sctaϕt+1+διt+1sctaιt+1+δωtsctaωt+1+ϵctbωth(sct)+ϵst+1bϕt+1=δϕt+1sct(i=1Iwiϕxit+1+h=1Hwhϕbht+c=1Cwcϕsct)+διt+1sct(i=1Iwiιxit+1+h=1Hwhιbht+c=1Cwsct)+δωtsct(i=1Iwxit+1+h=1Hwbht+c=1Cwcωsct)+ϵctbωth(sct)+ϵst+1bϕt+1=ϵctbωth(sct)+ϵst+1bϕt+1+δϕt+1wcϕ+διt+1w+δωtwcω

  4. Cell输出(短时记忆)。
    只需应用前向传播的 ( 3.2 ) (3.2) (3.2),即可得到:
    δ c t = ∂ L ∂ a c t = ∂ L ∂ s c t ∂ s c t ∂ a c t = ϵ s t ∂ [ b ϕ t ⋅ s c t − 1   + b ι t ⋅ g ( a c t ) ] ∂ a c t = ϵ s t b ι t g ′ ( a c t ) \begin{aligned} \delta_c^t &=\frac{\partial \mathcal{L}}{\partial a_c^t} =\frac{\partial \mathcal{L}}{\partial s_c^t}\frac{\partial s_c^t}{\partial a_c^t} \\ &=\epsilon_s^t \frac{\partial \left[b_\phi^t \cdot s_c^{t-1}\,+b_\iota^t \cdot g(a_c^t)\right] }{\partial a_c^t} \\ &=\epsilon_s^t b_\iota^tg'(a_c^t) \\ \end{aligned} δct=actL=sctLactsct=ϵstact[bϕtsct1+bιtg(act)]=ϵstbιtg(act)

  5. 遗忘门。方法同输出门推导,只需应用前向传播的 ( 3.2 ) (3.2) (3.2),可立即得到:
    δ ϕ t = ∂ L ∂ a ϕ t = ∂ L ∂ b ϕ t ∂ b ϕ t ∂ a ϕ t = ∂ L ∂ b ϕ t f ′ ( a ϕ t ) = f ′ ( a ϕ t ) ∂ L ∂ s c t ∂ s c t ∂ b ϕ t = f ′ ( a ϕ t ) ϵ s t ∂ s c t ∂ b ϕ t = f ′ ( a ϕ t ) ϵ s t ∂ [ b ϕ t s c t − 1 + b ι t g ( a c t ) ] ∂ b ϕ t = f ′ ( a ϕ t ) ∑ c = 1 C s c t − 1 ϵ s t \begin{aligned} \delta_\phi^t&= \frac{\partial \mathcal{L}}{\partial a_\phi^t} =\frac{\partial \mathcal{L}}{\partial b_\phi^t}\frac{\partial b_\phi^t}{\partial a_\phi^t} \\ &=\frac{\partial \mathcal{L}}{\partial b_\phi^t} f'(a_\phi^t) \\ &=f'(a_\phi^t) \frac{\partial \mathcal{L}}{\partial s_c^t} \frac{\partial s_c^t}{\partial b_\phi^t}\\ &= f'(a_\phi^t) \epsilon_s^t \frac{\partial s_c^t}{\partial b_\phi^t} \\ &= f'(a_\phi^t) \epsilon_s^t \frac{\partial \left[b_\phi^t s_c^{t-1} + b_\iota^{t} g(a_c^t)\right]}{\partial b_\phi^t} \\ &=f'(a_\phi^t)\sum_{c=1}^Cs_c^{t-1}\epsilon_s^t \end{aligned} δϕt=aϕtL=bϕtLaϕtbϕt=bϕtLf(aϕt)=f(aϕt)sctLbϕtsct=f(aϕt)ϵstbϕtsct=f(aϕt)ϵstbϕt[bϕtsct1+bιtg(act)]=f(aϕt)c=1Csct1ϵst

  6. 输入门。方法同输出门,只需应用前向传播的 ( 3.2 ) (3.2) (3.2),即可得到:
    δ ι t = ∂ L ∂ a ι t = ∂ L ∂ b ι t ∂ b ι t ∂ a ι t = ∂ L ∂ b ι t f ′ ( a ι t ) = f ′ ( a ι t ) ∂ L ∂ s c t ∂ s c t ∂ b ι t = f ′ ( a ι t ) ϵ s t ∂ s c t ∂ b ι t = f ′ ( a ι t ) ϵ c t ∂ [ b ϕ t s c t − 1 + b ι t g ( a c t ) ] ∂ b ι t = f ′ ( a ι t ) ∑ c = 1 C g ( a c t ) ϵ s t \begin{aligned} \delta_\iota^t&=\frac{\partial \mathcal{L}}{\partial a_\iota^t} =\frac{\partial \mathcal{L}}{\partial b_\iota^t}\frac{\partial b_\iota^t}{\partial a_\iota^t} \\ &=\frac{\partial \mathcal{L}}{\partial b_\iota^t} f'(a_\iota^t) \\ &=f'(a_\iota^t) \frac{\partial \mathcal{L}}{\partial s_c^t} \frac{\partial s_c^t}{\partial b_\iota^t}\\ &= f'(a_\iota^t) \epsilon_s^t \frac{\partial s_c^t}{\partial b_\iota^t} \\ &= f'(a_\iota^t) \epsilon_c^t \frac{\partial \left[b_\phi^t s_c^{t-1} + b_\iota^{t} g(a_c^t)\right]}{\partial b_\iota^t} \\ &=f'(a_\iota^t)\sum_{c=1}^Cg(a_c^t) \epsilon_s^t \end{aligned} διt=aιtL=bιtLaιtbιt=bιtLf(aιt)=f(aιt)sctLbιtsct=f(aιt)ϵstbιtsct=f(aιt)ϵctbιt[bϕtsct1+bιtg(act)]=f(aιt)c=1Cg(act)ϵst

总结

  本文介绍了这本书LSTM部分(第四章)的流程详解及公式推导,其中难免会有些许错误,望大家指出。得到公式后,下一步就是编程实现了,这里可以参考另一篇文章零基础入门深度学习(6) - 长短时记忆网络(LSTM),有非常细致的讲解。第一篇万字长文(其实主要是公式多),如果有用就点个赞吧!

P.S. PPT绘图大法是真的香,对我这种小白十分友好,有兴趣的朋友可以玩玩😆。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zorchp

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值