循环神经网络2--LSTM

这周在看循环数据网络, 发现一个博客, 里面推导极其详细, 借此记录重点.

详细推导

强烈建议手推一遍, 虽然会花一点时间, 但便于理清思路.

长短时记忆网络

回顾BPTT算法里误差项沿时间反向传播的公式:

δTk=δTti=kt1diag[f(neti)]W(1) (1) δ k T = δ t T ∏ i = k t − 1 d i a g [ f ′ ( n e t i ) ] W

根据范数的性质, 来获取 δTk δ k T 的模的上界:
δTkδTti=kt1diag[f(neti)]WδTt(βfβW)tk(2)(3) (2) ‖ δ k T ‖ ⩽ ‖ δ t T ‖ ∏ i = k t − 1 ‖ d i a g [ f ′ ( n e t i ) ] ‖ ‖ W ‖ (3) ⩽ ‖ δ t T ‖ ( β f β W ) t − k

可以看到, 误差项 δ δ 从t时刻传递到k时刻, 其值上界是 βfβw β f β w 的指数函数. βfβw β f β w 分别是对角矩阵 diag[f(neti)] d i a g [ f ′ ( n e t i ) ] 和矩阵W模的上界. 显然, 当t-k很大时, 会有 梯度爆炸, 当t-k很小时, 会有 梯度消失.

为了解决RNN的梯度爆炸和梯度消失的问题, 就出现了长短时记忆网络(Long Short Memory Network, LSTM). 原始RNN的隐藏层只有一个状态h, 它对于短期的输入非常敏感. 如果再增加一个状态c, 让它来保存长期的状态, 那么就可以解决原始RNN无法处理长距离依赖的问题.

img

新增加的状态c, 称为单元状态(cell state). 上图按照时间维度展开:

img

上图中, 在t时刻, LSTM的输入有三个: 当前时刻网络的输入值 xt x t , 上一时刻LSTM的输出值 ht1 h t − 1 , 以及上一时刻的单元状态 ct1 c t − 1 ; LSTM的输出有两个: 当前时刻的LSTM输出 ht h t , 当前时刻的状态 ct c t . 其中 x,h,c x , h , c 都是向量.

LSTM的关键在于怎样控制长期状态c. 在这里, LSTM的思路是使用三个控制开关:

第一个开关, 负责控制继续保存长期状态c; (遗忘门)

第二个开关, 负责控制把即时状态输入到长期状态c; (输入门)

第三个开关, 负责控制是都把长期状态c作为当前的LSTM的输出. (输出门)

img

接下来, 具体描述一下输出h和单元状态c的计算方法.

长短时记忆网络的前向计算

开关在算法中用门(gate)实现. 门实际上就是一层全连接层, 它的输入是一个向量, 输出是一个0~1的实数向量. 假设w是门的权重向量, b是偏置项, 门可以表示为:

g(x)=σ(Wx+b) g ( x ) = σ ( W x + b )

门的使用, 就是 用门的输出向量按元素乘以我们需要控制的那个向量. 当门的输出为0时, 任何向量与之相乘都会得到0向量, 相当于什么都不能通过; 当输出为1时, 任何向量与之相乘都为本身, 相当于什么都可以通过. 上式中 σ σ 是sigmoid函数, 值域为(0,1), 所以门的状态是半开半闭的.

LSTM用两个门来控制单元状态c的内容, 一个是遗忘门(forget gate), 它决定了上一时刻的单元状态 ct1 c t − 1 有多少保留到当前时刻 ct c t ; 另一个是输入门(input gate), 它决定了当前时刻网络的输入 xt x t 有多少保存到单元状态 ct c t . LSTM用输出门(output gate)来控制单元状态 ct c t 有多少输出到LSTM的当前输出值 ht h t .

1. 遗忘门:

ft=σ(Wf[ht1,xt]+bf)(1) f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) ( 式 1 )

上式中, Wf W f 是遗忘门的权重矩阵, [ht1,xt] [ h t − 1 , x t ] 表示把两个向量连接到一个更长的向量, bf b f 是遗忘门的偏置项, σ σ 是sigmoid函数. 如果输入的维度是 dh d h , 单元状态的维度是 dc d c (通常 dc=dh d c = d h ), 则遗忘门的权重矩阵 Wf W f 维度是 dc×(dh+dx) d c × ( d h + d x ) .

事实上, 权重矩阵 Wf W f 都是两个矩阵拼接而成的: 一个是 Wfh W f h , 它对应着输入项 ht1 h t − 1 , 其维度为 dc×dh d c × d h ; 一个是 Wfx W f x , 它对应着输入项 xt x t , 其维度为 dc×dh d c × d h . Wf W f 可以写成:

[Wf][ht1xt]=[WfhWfx][ht1xt]=Wfhht1+Wfxxt(4)(5) (4) [ W f ] [ h t − 1 x t ] = [ W f h W f x ] [ h t − 1 x t ] (5) = W f h h t − 1 + W f x x t

下图是遗忘门的计算:

img

2. 输入门:

it=σ(Wi[ht1,xt]+bi)(2) i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) ( 式 2 )

上式中, Wi W i 是输入门的权重矩阵, bi b i 是输入门的偏置项.

下图是输入门的计算:

img

接下来, 计算用于描述当前输入的单元状态 c̃t c ~ t , 它是根据根据上一次的输出和本次的输入来计算的:

c̃t=tanh(Wc[ht1,xt]+bc)(3) c ~ t = tanh ⁡ ( W c ⋅ [ h t − 1 , x t ] + b c ) ( 式 3 )

下图是 c̃t c ~ t 的计算:

img

现在, 我们计算当前时刻的单元状态 ct c t . 它是由上一次的单元状态 ct1 c t − 1 按元素乘以遗忘门 ft f t , 再用当前输入的单元状态 c̃t c ~ t 按元素乘以输入门 it i t , 再将两个积加和产生的:

ct=ftct1+itc̃t(4) c t = f t ∘ c t − 1 + i t ∘ c ~ t ( 式 4 )

符号 表示 按元素乘. 下图是 ct c t 的计算:

img

这样, 就把LSTM关于当前的记忆 c̃t c ~ t 和长期的记忆 ct1 c t − 1 组合在一起, 形成了新的单元状态 ct c t . 由于遗忘门的控制, 它可以保存很久之前的信息, 由于输入门的控制, 它又可以避免当前无关紧要的内容进入记忆.

3. 输出门

ot=σ(Wo[ht1,xt]+bo)(5) o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) ( 式 5 )

下图表示输出门的计算:

img

LSTM最终的输出, 是由输出门和单元状态共同确定的:

ht=ottanh(ct)(6) h t = o t ∘ tanh ⁡ ( c t ) ( 式 6 )

下图表示LSTM最终输出的计算:

img

式1式6就是LSTM前向计算的全部公式.

长短时记忆网络的训练

训练部分比前向计算部分复杂, 具体推导如下.

LSTM训练算法框架

LSTM的训练算法仍然是反向传播算法, 主要是三个步骤:

  1. 前向计算每个神经元的输出值, 对于LSTM来说, 即 ft,it,ctot,ht f t , i t , c t o t , h t 五个向量的值;
  2. 反向计算每个神经元的误差项 δ δ 值, 与RNN一样, LSTM误差项的反向传播也是包括两个方向: 一个沿时间的反向传播, 即从当前t时刻开始, 计算每个时刻的误差项; 一个是将误差项向上一层传播;
  3. 根据相应的误差项, 计算每个权重的梯度.

关于公式和符号的说明

接下来的推导, 设定gate的激活函数为sigmoid, 输出的激活函数为tanh函数. 他们的导数分别为:

σ(z)σ(z)tanh(z)tanh(z)=y=11+ez=y(1y)=y=ezezez+ez=1y2(6)(7)(8)(9) (6) σ ( z ) = y = 1 1 + e − z (7) σ ′ ( z ) = y ( 1 − y ) (8) tanh ⁡ ( z ) = y = e z − e − z e z + e − z (9) tanh ′ ⁡ ( z ) = 1 − y 2

从上式知, sigmoid函数和tanh函数的导数都是原函数的函数, 那么计算出原函数的值, 导数便也计算出来.

LSTM需要学习的参数共有8组, 权重矩阵的两部分在反向传播中使用不同的公式, 分别是:

  1. 遗忘门的权重矩阵 Wf W f 和偏置项 bt b t , Wf W f 分开为两个矩阵 Wfh W f h Wfx W f x
  2. 输入门的权重矩阵 Wi W i 和偏置项 bi b i , Wi W i 分开为两个矩阵 Wih W i h Wxi W x i
  3. 输出门的权重矩阵 Wo W o 和偏置项 bo b o , Wo W o 分开为两个矩阵 Woh W o h Wox W o x
  4. 计算单元状态的权重矩阵 Wc W c 和偏置项 bc b c , Wc W c 分开为两个矩阵 Wch W c h Wcx W c x

按元素乘 符号. 当 作用于两个向量时, 运算如下:

ab=a1a2a3...anb1b2b3...bn=a1b1a2b2a3b3...anbn a ∘ b = [ a 1 a 2 a 3 . . . a n ] ∘ [ b 1 b 2 b 3 . . . b n ] = [ a 1 b 1 a 2 b 2 a 3 b 3 . . . a n b n ]

作用于 一个向量一个矩阵时, 运算如下:
aX=a1a2a3...anx11x21x31xn1x12x22x32xn2x13x23x33...xn3............x1nx2nx3nxnn=a1x11a2x21a3x31anxn1a1x12a2x22a3x32anxn2a1x13a2x23a3x33...anxn3............a1x1na2x2na3x3nanxnn(10)(11) (10) a ∘ X = [ a 1 a 2 a 3 . . . a n ] ∘ [ x 11 x 12 x 13 . . . x 1 n x 21 x 22 x 23 . . . x 2 n x 31 x 32 x 33 . . . x 3 n . . . x n 1 x n 2 x n 3 . . . x n n ] (11) = [ a 1 x 11 a 1 x 12 a 1 x 13 . . . a 1 x 1 n a 2 x 21 a 2 x 22 a 2 x 23 . . . a 2 x 2 n a 3 x 31 a 3 x 32 a 3 x 33 . . . a 3 x 3 n . . . a n x n 1 a n x n 2 a n x n 3 . . . a n x n n ]

作用于 两个矩阵时, 两个矩阵对应位置的元素相乘. 按元素乘可以在某些情况下简化矩阵和向量运算.

例如, 当一个对角矩阵右乘一个矩阵时, 相当于用对角矩阵的对角线组成的向量按元素乘那个矩阵:

diag[a]X=aX d i a g [ a ] X = a ∘ X

当一个行向量左乘一个对角矩阵时, 相当于这个行向量按元素乘那个矩阵对角组成的向量:
aTdiag[b]=ab a T d i a g [ b ] = a ∘ b

在t时刻, LSTM的输出值为 ht h t . 我们定义t时刻的误差项 δt δ t 为:
δt=defEht δ t = d e f ∂ E ∂ h t

这里假设误差项是损失函数对输出值的导数, 而不是对加权输出 netlt n e t t l 的导数. 因为LSTM有四个加权输入, 分别对应 ft,it,ct,ot f t , i t , c t , o t , 我们希望往上一层传递一个误差项而不是四个, 但需要定义这四个加权输入以及它们对应的误差项.
netf,tneti,tnetc̃,tneto,tδf,tδi,tδc̃,tδo,t=Wf[ht1,xt]+bf=Wfhht1+Wfxxt+bf=Wi[ht1,xt]+bi=Wihht1+Wixxt+bi=Wc[ht1,xt]+bc=Wchht1+Wcxxt+bc=Wo[ht1,xt]+bo=Wohht1+Woxxt+bo=defEnetf,t=defEneti,t=defEnetc̃,t=defEneto,t(12)(13)(14)(15)(16)(17)(18)(19)(20)(21)(22)(23) (12) n e t f , t = W f [ h t − 1 , x t ] + b f (13) = W f h h t − 1 + W f x x t + b f (14) n e t i , t = W i [ h t − 1 , x t ] + b i (15) = W i h h t − 1 + W i x x t + b i (16) n e t c ~ , t = W c [ h t − 1 , x t ] + b c (17) = W c h h t − 1 + W c x x t + b c (18) n e t o , t = W o [ h t − 1 , x t ] + b o (19) = W o h h t − 1 + W o x x t + b o (20) δ f , t = d e f ∂ E ∂ n e t f , t (21) δ i , t = d e f ∂ E ∂ n e t i , t (22) δ c ~ , t = d e f ∂ E ∂ n e t c ~ , t (23) δ o , t = d e f ∂ E ∂ n e t o , t

误差项沿时间的反向传递

沿时间反向传递误差项, 就是要计算出t-1时刻的误差项 δt1 δ t − 1 .

δTt1=Eht1=Ehththt1=δTththt1(24)(25)(26) (24) δ t − 1 T = ∂ E ∂ h t − 1 (25) = ∂ E ∂ h t ∂ h t ∂ h t − 1 (26) = δ t T ∂ h t ∂ h t − 1

其中, htht1 ∂ h t ∂ h t − 1 是一个Jacobian矩阵, 为了求出它, 需要列出 ht h t 的计算公式, 即前面的 式6式4:
ht=ottanh(ct)(6)ct=ftct1+itc̃t(4) h t = o t ∘ tanh ⁡ ( c t ) ( 式 6 ) c t = f t ∘ c t − 1 + i t ∘ c ~ t ( 式 4 )

显然, ot,ft,it,c̃t o t , f t , i t , c ~ t 都是 ht1 h t − 1 的函数, 那么, 利用全导数公式可得:
δTththt1=δTthtototneto,tneto,tht1+δTthtctctftftnetf,tnetf,tht1+δTthtctctititneti,tneti,tht1+δTthtctctc̃tc̃tnetc̃,tnetc̃,tht1=δTo,tneto,tht1+δTf,tnetf,tht1+δTi,tneti,tht1+δTc̃,tnetc̃,tht1(7)(27)(28)(29) (27) δ t T ∂ h t ∂ h t − 1 = δ t T ∂ h t ∂ o t ∂ o t ∂ n e t o , t ∂ n e t o , t ∂ h t − 1 + δ t T ∂ h t ∂ c t ∂ c t ∂ f t ∂ f t ∂ n e t f , t ∂ n e t f , t ∂ h t − 1 (28) + δ t T ∂ h t ∂ c t ∂ c t ∂ i t ∂ i t ∂ n e t i , t ∂ n e t i , t ∂ h t − 1 + δ t T ∂ h t ∂ c t ∂ c t ∂ c ~ t ∂ c ~ t ∂ n e t c ~ , t ∂ n e t c ~ , t ∂ h t − 1 (29) = δ o , t T ∂ n e t o , t ∂ h t − 1 + δ f , t T ∂ n e t f , t ∂ h t − 1 + δ i , t T ∂ n e t i , t ∂ h t − 1 + δ c ~ , t T ∂ n e t c ~ , t ∂ h t − 1 ( 式 7 )

下面, 要把 式7中的每个偏导数都求出来, 根据 式6, 可以求出:
htothtct=diag[tanh(ct)]=diag[ot(1tanh(ct)2)](30)(31) (30) ∂ h t ∂ o t = d i a g [ tanh ⁡ ( c t ) ] (31) ∂ h t ∂ c t = d i a g [ o t ∘ ( 1 − tanh ⁡ ( c t ) 2 ) ]

根据 式4, 可以求出:
ctftctitctc̃t=diag[ct1]=diag[c̃t]=diag[it](32)(33)(34) (32) ∂ c t ∂ f t = d i a g [ c t − 1 ] (33) ∂ c t ∂ i t = d i a g [ c ~ t ] (34) ∂ c t ∂ c ~ t = d i a g [ i t ]

因为:
otneto,tftnetf,titneti,tc̃tnetc̃,t=σ(neto,t)=Wohht1+Woxxt+bo=σ(netf,t)=Wfhht1+Wfxxt+bf=σ(neti,t)=Wihht1+Wixxt+bi=tanh(netc̃,t)=Wchht1+Wcxxt+bc(35)(36)(37)(38)(39)(40)(41)(42)(43)(44)(45) (35) o t = σ ( n e t o , t ) (36) n e t o , t = W o h h t − 1 + W o x x t + b o (37) (38) f t = σ ( n e t f , t ) (39) n e t f , t = W f h h t − 1 + W f x x t + b f (40) (41) i t = σ ( n e t i , t ) (42) n e t i , t = W i h h t − 1 + W i x x t + b i (43) (44) c ~ t = tanh ⁡ ( n e t c ~ , t ) (45) n e t c ~ , t = W c h h t − 1 + W c x x t + b c

可以得出:
otneto,tneto,tht1ftnetf,tnetf,tht1itneti,tneti,tht1c̃tnetc̃,tnetc̃,tht1=diag[ot(1ot)]=Woh=diag[ft(1ft)]=Wfh=diag[it(1it)]=Wih=diag[1c̃2t]=Wch(46)(47)(48)(49)(50)(51)(52)(53) (46) ∂ o t ∂ n e t o , t = d i a g [ o t ∘ ( 1 − o t ) ] (47) ∂ n e t o , t ∂ h t − 1 = W o h (48) ∂ f t ∂ n e t f , t = d i a g [ f t ∘ ( 1 − f t ) ] (49) ∂ n e t f , t ∂ h t − 1 = W f h (50) ∂ i t ∂ n e t i , t = d i a g [ i t ∘ ( 1 − i t ) ] (51) ∂ n e t i , t ∂ h t − 1 = W i h (52) ∂ c ~ t ∂ n e t c ~ , t = d i a g [ 1 − c ~ t 2 ] (53) ∂ n e t c ~ , t ∂ h t − 1 = W c h

将上述偏导数导入到 式7, 可以得到:
δt1=δTo,tneto,tht1+δTf,tnetf,tht1+δTi,tneti,tht1+δTc̃,tnetc̃,tht1=δTo,tWoh+δTf,tWfh+δTi,tWih+δTc̃,tWch(8)(54)(55) (54) δ t − 1 = δ o , t T ∂ n e t o , t ∂ h t − 1 + δ f , t T ∂ n e t f , t ∂ h t − 1 + δ i , t T ∂ n e t i , t ∂ h t − 1 + δ c ~ , t T ∂ n e t c ~ , t ∂ h t − 1 (55) = δ o , t T W o h + δ f , t T W f h + δ i , t T W i h + δ c ~ , t T W c h ( 式 8 )

根据 δo,t,δf,t,δi,t,δc̃,t δ o , t , δ f , t , δ i , t , δ c ~ , t 的定义, 可知:
δTo,tδTf,tδTi,tδTc̃,t=δTttanh(ct)ot(1ot)(9)=δTtot(1tanh(ct)2)ct1ft(1ft)(10)=δTtot(1tanh(ct)2)c̃tit(1it)(11)=δTtot(1tanh(ct)2)it(1c̃2)(12)(56)(57)(58)(59) (56) δ o , t T = δ t T ∘ tanh ⁡ ( c t ) ∘ o t ∘ ( 1 − o t ) ( 式 9 ) (57) δ f , t T = δ t T ∘ o t ∘ ( 1 − tanh ⁡ ( c t ) 2 ) ∘ c t − 1 ∘ f t ∘ ( 1 − f t ) ( 式 10 ) (58) δ i , t T = δ t T ∘ o t ∘ ( 1 − tanh ⁡ ( c t ) 2 ) ∘ c ~ t ∘ i t ∘ ( 1 − i t ) ( 式 11 ) (59) δ c ~ , t T = δ t T ∘ o t ∘ ( 1 − tanh ⁡ ( c t ) 2 ) ∘ i t ∘ ( 1 − c ~ 2 ) ( 式 12 )

式8式12就是将误差沿时间反向传播一个时刻的公式. 有了它, 便可以写出将误差项传递到任意k时刻的公式:
δTk=j=kt1δTo,jWoh+δTf,jWfh+δTi,jWih+δTc̃,jWch(13) δ k T = ∏ j = k t − 1 δ o , j T W o h + δ f , j T W f h + δ i , j T W i h + δ c ~ , j T W c h ( 式 13 )

将误差项传递到上一层

假设当前是第 l l 层, 定义l1层的误差项是误差函数对 l1 l − 1 加权输入的导数, 即:

δl1t=defEnetl1t δ t l − 1 = d e f ∂ E n e t t l − 1

本次LSTM的输入 xt x t 由下面的公式计算:
xlt=fl1(netl1t) x t l = f l − 1 ( n e t t l − 1 )

上式中, fl1 f l − 1 表示第 l1 l − 1 激活函数.

因为 netlf,t,netli,t,netlc̃,t,netlo,t n e t f , t l , n e t i , t l , n e t c ~ , t l , n e t o , t l 都是 xt x t 的函数, xt x t 又是 netl1t n e t t l − 1 的函数, 因此, 要求出 E E netl1t n e t t l − 1 的导数, 就需要使用全导数公式:

Enetl1t=Enetlf,tnetlf,txltxltnetl1t+Enetli,tnetli,txltxltnetl1t+Enetlc̃,tnetlc̃,txltxltnetl1t+Enetlo,tnetlo,txltxltnetl1t=δTf,tWfxf(netl1t)+δTi,tWixf(netl1t)+δTc̃,tWcxf(netl1t)+δTo,tWoxf(netl1t)=(δTf,tWfx+δTi,tWix+δTc̃,tWcx+δTo,tWox)f(netl1t)(14)(60)(61)(62)(63) (60) ∂ E ∂ n e t t l − 1 = ∂ E ∂ n e t f , t l ∂ n e t f , t l ∂ x t l ∂ x t l ∂ n e t t l − 1 + ∂ E ∂ n e t i , t l ∂ n e t i , t l ∂ x t l ∂ x t l ∂ n e t t l − 1 (61) + ∂ E ∂ n e t c ~ , t l ∂ n e t c ~ , t l ∂ x t l ∂ x t l ∂ n e t t l − 1 + ∂ E ∂ n e t o , t l ∂ n e t o , t l ∂ x t l ∂ x t l ∂ n e t t l − 1 (62) = δ f , t T W f x ∘ f ′ ( n e t t l − 1 ) + δ i , t T W i x ∘ f ′ ( n e t t l − 1 ) + δ c ~ , t T W c x ∘ f ′ ( n e t t l − 1 ) + δ o , t T W o x ∘ f ′ ( n e t t l − 1 ) (63) = ( δ f , t T W f x + δ i , t T W i x + δ c ~ , t T W c x + δ o , t T W o x ) ∘ f ′ ( n e t t l − 1 ) ( 式 14 )

式14就是将误差传递到上一层的公式.

权重梯度的计算

对于 Wfh,Wih,Wch,Woh W f h , W i h , W c h , W o h 的权重梯度, 我们知道它的梯度是各个时刻梯度之和. 我们首先求出它们在t时刻的梯度, 然后再求出他们最终的梯度.

我们已经求得了误差项 δo,t,δf,t,δi,t,δc̃,t δ o , t , δ f , t , δ i , t , δ c ~ , t , 很容易求出t时刻的 Woh,Wfh,Wih,Wch W o h , W f h , W i h , W c h :

EWoh,tEWfh,tEWih,tEWch,t=Eneto,tneto,tWoh,t=δo,thTt1=Enetf,tnetf,tWfh,t=δf,thTt1=Eneti,tneti,tWih,t=δi,thTt1=Enetc̃,tnetc̃,tWch,t=δc̃,thTt1(64)(65)(66)(67)(68)(69)(70)(71)(72)(73)(74) (64) ∂ E ∂ W o h , t = ∂ E ∂ n e t o , t ∂ n e t o , t ∂ W o h , t (65) = δ o , t h t − 1 T (66) (67) ∂ E ∂ W f h , t = ∂ E ∂ n e t f , t ∂ n e t f , t ∂ W f h , t (68) = δ f , t h t − 1 T (69) (70) ∂ E ∂ W i h , t = ∂ E ∂ n e t i , t ∂ n e t i , t ∂ W i h , t (71) = δ i , t h t − 1 T (72) (73) ∂ E ∂ W c h , t = ∂ E ∂ n e t c ~ , t ∂ n e t c ~ , t ∂ W c h , t (74) = δ c ~ , t h t − 1 T

将各个时刻的梯度加在一起, 就能得到最终的梯度:

EWohEWfhEWihEWch=j=1tδo,jhTj1=j=1tδf,jhTj1=j=1tδi,jhTj1=j=1tδc̃,jhTj1(75)(76)(77)(78) (75) ∂ E ∂ W o h = ∑ j = 1 t δ o , j h j − 1 T (76) ∂ E ∂ W f h = ∑ j = 1 t δ f , j h j − 1 T (77) ∂ E ∂ W i h = ∑ j = 1 t δ i , j h j − 1 T (78) ∂ E ∂ W c h = ∑ j = 1 t δ c ~ , j h j − 1 T

对于偏置项 bf,bi,bc,bo b f , b i , b c , b o 的梯度, 先求出各个时刻的偏置项梯度:
Ebo,tEbf,tEbi,tEbc,t=Eneto,tneto,tbo,t=δo,t=Enetf,tnetf,tbf,t=δf,t=Eneti,tneti,tbi,t=δi,t=Enetc̃,tnetc̃,tbc,t=δc̃,t(79)(80)(81)(82)(83)(84)(85)(86)(87)(88)(89) (79) ∂ E ∂ b o , t = ∂ E ∂ n e t o , t ∂ n e t o , t ∂ b o , t (80) = δ o , t (81) (82) ∂ E ∂ b f , t = ∂ E ∂ n e t f , t ∂ n e t f , t ∂ b f , t (83) = δ f , t (84) (85) ∂ E ∂ b i , t = ∂ E ∂ n e t i , t ∂ n e t i , t ∂ b i , t (86) = δ i , t (87) (88) ∂ E ∂ b c , t = ∂ E ∂ n e t c ~ , t ∂ n e t c ~ , t ∂ b c , t (89) = δ c ~ , t

将各个时刻的偏置项梯度加在一起:
EboEbiEbfEbc=j=1tδo,j=j=1tδi,j=j=1tδf,j=j=1tδc̃,j(90)(91)(92)(93) (90) ∂ E ∂ b o = ∑ j = 1 t δ o , j (91) ∂ E ∂ b i = ∑ j = 1 t δ i , j (92) ∂ E ∂ b f = ∑ j = 1 t δ f , j (93) ∂ E ∂ b c = ∑ j = 1 t δ c ~ , j

对于 Wfx,Wix,Wcx,Wox W f x , W i x , W c x , W o x 的权重梯度, 只需要根据相应的误差项直接计算即可:
EWoxEWfxEWixEWcx=Eneto,tneto,tWox=δo,txTt=Enetf,tnetf,tWfx=δf,txTt=Eneti,tneti,tWix=δi,txTt=Enetc̃,tnetc̃,tWcx=δc̃,txTt(94)(95)(96)(97)(98)(99)(100)(101)(102)(103)(104) (94) ∂ E ∂ W o x = ∂ E ∂ n e t o , t ∂ n e t o , t ∂ W o x (95) = δ o , t x t T (96) (97) ∂ E ∂ W f x = ∂ E ∂ n e t f , t ∂ n e t f , t ∂ W f x (98) = δ f , t x t T (99) (100) ∂ E ∂ W i x = ∂ E ∂ n e t i , t ∂ n e t i , t ∂ W i x (101) = δ i , t x t T (102) (103) ∂ E ∂ W c x = ∂ E ∂ n e t c ~ , t ∂ n e t c ~ , t ∂ W c x (104) = δ c ~ , t x t T

以上就是LSTM的训练算法的全部公式

GRU

上面所述是一种普通的LSTM, 事实上LSTM存在很多变体, GRU就是其中一种最成功的变体. 它对LSTM做了很多简化, 同时保持和LSTM相同的效果.

GRU对LSTM做了两大改动:

  1. 将输入门, 遗忘门, 输出门变为两个门: 更新门(Update Gate) zt z t 和充值门(Reset Gate) rt r t .
  2. 将单元状态与输出合并为一个状态: h h

GRU的前向计算公式为:

ztrth̃th=σ(Wz[ht1,xt])=σ(Wr[ht1,xt])=tanh(W[rtht1,xt])=(1zt)ht1+zth̃t(105)(106)(107)(108) (105) z t = σ ( W z ⋅ [ h t − 1 , x t ] ) (106) r t = σ ( W r ⋅ [ h t − 1 , x t ] ) (107) h ~ t = tanh ⁡ ( W ⋅ [ r t ∘ h t − 1 , x t ] ) (108) h = ( 1 − z t ) ∘ h t − 1 + z t ∘ h ~ t

下图是GRU的示意图:

img

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值