引言
LSTM是RNN的变种,是为了解决RNN存在的长期依赖问题而专门设计出来的。所谓长期依赖问题是,后面的单词在很长的时间序列后还依赖前面的单词,但由于梯度消失问题,导致前面的单词无法影响到后面的单词。
LSTM单元
LSTM单元(cell)在每个时间点更新单元状态 c ⟨ t ⟩ c^{\langle t \rangle} c⟨t⟩,它决定了 a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩的值。LSTM有更新门、遗忘门和输出门来控制这些值。
下面来对LSTM中的元素做一些说明
遗忘门
遗忘门用来控制内存中之前的状态是否会被遗忘掉。
- 如果遗忘门的值是0,LSTM会遗忘(忽略)之前的状态
- 如果遗忘门的值是1,LSTM会记得(保持)之前的状态
- 如果是0到1之间的值,代表LSTM会记得之前的状态多大程度
公式为:
Γ f ⟨ t ⟩ = σ ( W f [ a ⟨ t − 1 ⟩ , x ⟨ t ⟩ ] + b f ) (1) \mathbf{\Gamma}_f^{\langle t \rangle} = \sigma(\mathbf{W}_f[\mathbf{a}^{\langle t-1 \rangle}, \mathbf{x}^{\langle t \rangle}] + \mathbf{b}_f)\tag{1} Γf⟨t⟩=σ(Wf[a⟨t−1⟩,x⟨t⟩]+bf)(1)
- W f W_f Wf和 b f b_f bf是可学习的权重和偏差
- 通过sigmoid函数来保证输出的值在[0,1]之间
- 遗忘门 Γ f ⟨ t ⟩ \mathbf{\Gamma}_f^{\langle t \rangle} Γf⟨t⟩与之前单元状态 c ⟨ t ⟩ c^{\langle t \rangle} c⟨t⟩同维度,即它们能逐元素相乘
在代码中Wf
代表
W
f
W_f
Wf,bf
代表
b
f
b_f
bf,ft
代表
Γ
f
⟨
t
⟩
\mathbf{\Gamma}_f^{\langle t \rangle}
Γf⟨t⟩
候选值 c ~ ⟨ t ⟩ \tilde{\mathbf{c}}^{\langle t \rangle} c~⟨t⟩
- 候选值保存的是当前时间点可能会存入当前单元状态( c ⟨ t ⟩ c^{\langle t \rangle} c⟨t⟩)的信息
- 候选值能多大程度的存入当前单元状态取决于更新门
公式为:
c ~ ⟨ t ⟩ = tanh ( W c [ a ⟨ t − 1 ⟩ , x ⟨ t ⟩ ] + b c ) (2) \mathbf{\tilde{c}}^{\langle t \rangle} = \tanh\left( \mathbf{W}_{c} [\mathbf{a}^{\langle t - 1 \rangle}, \mathbf{x}^{\langle t \rangle}] + \mathbf{b}_{c} \right) \tag{2} c~⟨t⟩=tanh(Wc[a⟨t−1⟩,x⟨t⟩]+bc)(2)
这里用的是tanh函数,所以取值范围为[-1,1]
cct
代表
c
~
⟨
t
⟩
\tilde{\mathbf{c}}^{\langle t \rangle}
c~⟨t⟩
Wc
代表
W
c
W_c
Wc
更新门(输入门)
- 更新门决定候选值(哪些维度)能多大程度的存入当前单元状态
- 如果更新门的值是0,意味着防止候选值存入单元状态
- 如果更新门的值是1,意味着完全允许候选值存入单元状态
有些文献称它为输入门,并且用"i"来表示,这里沿用这种约定
公式:
Γ i ⟨ t ⟩ = σ ( W i [ a ⟨ t − 1 ⟩ , x ⟨ t ⟩ ] + b i ) (3) \mathbf{\Gamma}_i^{\langle t \rangle} = \sigma(\mathbf{W}_i[a^{\langle t-1 \rangle}, \mathbf{x}^{\langle t \rangle}] + \mathbf{b}_i)\tag{3} Γi⟨t⟩=σ(Wi[a⟨t−1⟩,x⟨t⟩]+bi)(3)
Wi
代表
W
i
W_i
Wi,bi
代表
b
i
b_i
bi,it
代表更新门
Γ
i
⟨
t
⟩
\mathbf{\Gamma}_i^{\langle t \rangle}
Γi⟨t⟩。
单元状态 c ⟨ t ⟩ c^{\langle t \rangle} c⟨t⟩
- 单元状态是时间序列间传递的"记忆"
- 新的单元状态由之前的状态和当前候选值组成
公式:
c ⟨ t ⟩ = Γ f ⟨ t ⟩ ∗ c ⟨ t − 1 ⟩ + Γ i ⟨ t ⟩ ∗ c ~ ⟨ t ⟩ (4) \mathbf{c}^{\langle t \rangle} = \mathbf{\Gamma}_f^{\langle t \rangle}* \mathbf{c}^{\langle t-1 \rangle} + \mathbf{\Gamma}_{i}^{\langle t \rangle} *\mathbf{\tilde{c}}^{\langle t \rangle} \tag{4} c⟨t⟩=Γf⟨t⟩∗c⟨t−1⟩+Γi⟨t⟩∗c~⟨t⟩(4)
- 结合上面所有的公式,得到了单元状态的计算公式
- 前一单元状态由遗忘门控制会有多少被保存到当前单元状态中
- 候选值由更新门控制能有多少被保存到当前单元状态中
c
:所有时间点的单元状态
c
c
c,形状是
(
n
a
,
m
,
T
)
(n_a,m,T)
(na,m,T)
c_next
:当前单元状态
c
⟨
t
⟩
c^{\langle t \rangle}
c⟨t⟩,形状
(
n
a
,
m
)
(n_a,m)
(na,m)
c_prev
: 前一个单元状态
c
⟨
t
−
1
⟩
c^{\langle t-1 \rangle}
c⟨t−1⟩,形状
(
n
a
,
m
)
(n_a,m)
(na,m)
输出门 Γ o \mathbf{\Gamma}_{o} Γo
- 输出门控制了当前时间点能输出什么
- 和之前所有门一样,取值范围[0,1]
公式:
Γ
o
⟨
t
⟩
=
σ
(
W
o
[
a
⟨
t
−
1
⟩
,
x
⟨
t
⟩
]
+
b
o
)
(5)
\mathbf{\Gamma}_o^{\langle t \rangle}= \sigma(\mathbf{W}_o[\mathbf{a}^{\langle t-1 \rangle}, \mathbf{x}^{\langle t \rangle}] + \mathbf{b}_{o})\tag{5}
Γo⟨t⟩=σ(Wo[a⟨t−1⟩,x⟨t⟩]+bo)(5)
W_o
代表输出门的权重
W
o
W_o
Wo,bo
代表输出门的偏差
b
o
b_o
bo,ot
代表输出门
Γ
o
\mathbf{\Gamma}_{o}
Γo
从三个门的公式可以看出,它们的激活函数都是sigmoid,取值都是[0,1],输入都是 a ⟨ t − 1 ⟩ a^{\langle t-1 \rangle} a⟨t−1⟩和 x ⟨ t ⟩ x^{\langle t \rangle} x⟨t⟩,唯一的区别是可学习的权重和偏差不一样。如果取值为0,表示这个门是关闭的;取值为1,表示这个门是完全打开的;取值 ( 0 , 1 ) (0,1) (0,1)表示这个门是半关半开的,只允许一部分的值进入(被保存,被传递)。
隐藏状态
- 当前的隐藏状态会传递到下一个时间点的LSTM单元
- 它用于决定下个时间点的三个门
- 同时也用于当前时间点的预测(输出值 y ^ ⟨ t ⟩ \hat y^{\langle t \rangle} y^⟨t⟩)
公式:
a
⟨
t
⟩
=
Γ
o
⟨
t
⟩
∗
tanh
(
c
⟨
t
⟩
)
(6)
\mathbf{a}^{\langle t \rangle} = \mathbf{\Gamma}_o^{\langle t \rangle} * \tanh(\mathbf{c}^{\langle t \rangle})\tag{6}
a⟨t⟩=Γo⟨t⟩∗tanh(c⟨t⟩)(6)
- 隐藏状态由单元状态和输出门决定
- 单元状态传递到tanh函数得到 [ − 1 , 1 ] [-1,1] [−1,1]的取值
a
: 所有的隐藏状态
a
a
a,形状
(
n
a
,
m
,
T
x
)
(n_a,m,T_x)
(na,m,Tx)
a_prev
: 上个时间点的隐藏状态
a
⟨
t
−
1
⟩
a^{\langle t-1 \rangle}
a⟨t−1⟩,形状
(
n
a
,
m
)
(n_a,m)
(na,m)
a_next
: 当前时间点的隐藏状态
a
⟨
t
⟩
a^{\langle t \rangle}
a⟨t⟩,形状
(
n
a
,
m
)
(n_a,m)
(na,m)
预测值 y ^ ⟨ t ⟩ \hat y^{\langle t \rangle} y^⟨t⟩
- 在分类问题中的输出值使用softmax函数
z ⟨ t ⟩ = W y a ⟨ t ⟩ + b y (7) z^{\langle t \rangle} = \mathbf{W}_{y} \mathbf{a}^{\langle t \rangle} + \mathbf{b}_{y} \tag{7} z⟨t⟩=Wya⟨t⟩+by(7)
y p r e d ⟨ t ⟩ = softmax ( z ⟨ t ⟩ ) (8) \mathbf{y}^{\langle t \rangle}_{pred} = \textrm{softmax}(z^{\langle t \rangle}) \tag{8} ypred⟨t⟩=softmax(z⟨t⟩)(8)
y_pred
: 所有时间点的预测值
y
p
r
e
d
y_{pred}
ypred,形状
(
n
y
,
m
,
T
x
)
(n_y,m,T_x)
(ny,m,Tx)
yt_pred
: 当前时间点的预测值
y
p
r
e
d
⟨
t
⟩
y_{pred}^{\langle t \rangle}
ypred⟨t⟩,形状
(
n
y
,
m
)
(n_y,m)
(ny,m)
至此我们知道了LSTM单元中的所有计算公式,下面来看如何实现前向传播和反向传播。
前向传播
实现如上图所示的前向传播过程,我们需要代码化上面的公式 ( 1 ) (1) (1)~ ( 7 ) (7) (7)。
要注意的是,我们会叠加前一个隐藏状态
a
⟨
t
−
1
⟩
a^{\langle t-1 \rangle}
a⟨t−1⟩和当前的输入
x
⟨
t
⟩
x^{\langle t \rangle}
x⟨t⟩到一个矩阵concat
:
c o n c a t = [ a ⟨ t − 1 ⟩ x ⟨ t ⟩ ] concat = \begin{bmatrix} a^{\langle t-1 \rangle} \\ x^{\langle t \rangle} \end{bmatrix} concat=[a⟨t−1⟩x⟨t⟩]
反向传播
LSTM的反向传播比RNN的要复杂一点。不过遵循规则——求某个节点的梯度时,考虑该节点的所有输出节点。分别计算每个输出节点的梯度乘上输出节点对该节点的梯度,然后加起来就得到该节点的梯度,也不难。
首先列出激活函数的导数:
d
tanh
(
x
)
=
1
−
tanh
(
x
)
2
d \tanh(x) = 1 - \tanh(x)^2
dtanh(x)=1−tanh(x)2
d
σ
(
x
)
=
σ
(
x
)
(
1
−
σ
(
x
)
)
d \sigma(x) = \sigma(x)(1 - \sigma(x))
dσ(x)=σ(x)(1−σ(x))
假设考虑的LSTM结构为多对多的,且
T
x
=
T
y
T_x=T_y
Tx=Ty,每个时刻
t
t
t都有一个输出及一个损失
l
(
t
)
l(t)
l(t),全局损失函数为:
L
=
∑
i
=
1
T
x
l
(
t
)
(9)
L = \sum_{i=1}^{T_x} l(t) \tag{9}
L=i=1∑Txl(t)(9)
我们求 L L L对 z ⟨ t ⟩ z^{\langle t \rangle} z⟨t⟩的导数 d z ⟨ t ⟩ dz^{\langle t \rangle} dz⟨t⟩,具体过程可以参考博客 Softmax与Cross-entropy的求导,得到:
d z ⟨ t ⟩ = y ^ ⟨ t ⟩ − y ⟨ t ⟩ (10) dz^{\langle t \rangle} = \hat y^{\langle t \rangle} - y^{\langle t \rangle} \tag{10} dz⟨t⟩=y^⟨t⟩−y⟨t⟩(10)
根据公式 ( 7 ) (7) (7),可以很容易的求出:
d
W
y
=
d
z
⋅
a
⟨
t
⟩
(11)
dW_y = dz \cdot a^{\langle t \rangle} \tag{11}
dWy=dz⋅a⟨t⟩(11)
d
b
y
=
d
z
(12)
db_y = dz \tag{12}
dby=dz(12)
而求 d a da da和 d c dc dc时要分两种情况考虑:
在时刻
T
x
T_x
Tx时,
d
a
⟨
T
x
⟩
=
∂
L
∂
a
⟨
T
x
⟩
=
∂
l
(
T
x
)
∂
a
⟨
T
x
⟩
=
d
z
⟨
T
x
⟩
W
y
(13)
da^{\langle T_x \rangle} = \frac{\partial L}{\partial a^{\langle T_x \rangle}}= \frac{\partial l(T_x)}{\partial a^{\langle T_x \rangle}} = dz^{\langle T_x \rangle} W_y \tag{13}
da⟨Tx⟩=∂a⟨Tx⟩∂L=∂a⟨Tx⟩∂l(Tx)=dz⟨Tx⟩Wy(13)
d c ⟨ T x ⟩ = ∂ L ∂ c ⟨ T x ⟩ = d a ⟨ T x ⟩ ⋅ Γ o ⟨ T x ⟩ ⋅ ( 1 − t a n h ( c ⟨ T x ⟩ ) 2 ) (14) dc^{\langle T_x \rangle} = \frac{\partial L}{\partial c^{\langle T_x \rangle}} = da^{\langle T_x \rangle}\cdot \mathbf{\Gamma}_o^{\langle T_x \rangle} \cdot (1 - tanh(c^{\langle T_x \rangle})^2) \tag{14} dc⟨Tx⟩=∂c⟨Tx⟩∂L=da⟨Tx⟩⋅Γo⟨Tx⟩⋅(1−tanh(c⟨Tx⟩)2)(14)
在时刻 t ( t < T x ) t \,\,(t < T_x) t(t<Tx)时, a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩的后续同时有 a ⟨ t + 1 ⟩ a^{\langle t+1 \rangle} a⟨t+1⟩(大于 t t t时刻的误差)和 y ⟨ t ⟩ y^{\langle t \rangle} y⟨t⟩( t t t时刻的误差)两个节点。因此计算梯度时要考虑这两部分:
d a ⟨ t ⟩ = ∂ a ⟨ t + 1 ⟩ a ⟨ t ⟩ + ∂ l ( t ) ∂ a ⟨ t ⟩ = ∂ L ( t + 1 ) ∂ a ⟨ t + 1 ⟩ ∂ a ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ + d z ⟨ t ⟩ W y = d z ⟨ t + 1 ⟩ W y ⋅ ∂ a ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ + d z ⟨ t ⟩ W y (15) da^{\langle t \rangle} = \frac{\partial a^{\langle t+1 \rangle}}{a^{\langle t \rangle}} + \frac{\partial l(t)}{\partial a^{\langle t \rangle}} \\ = \frac{\partial L(t+1)}{\partial a^{\langle t+1 \rangle}}\frac{\partial a^{\langle t+1 \rangle}}{\partial a^{\langle t \rangle}} + dz^{\langle t \rangle} W_y \\ = dz^{\langle t+1 \rangle} W_y \cdot \frac{\partial a^{\langle t+1 \rangle}}{\partial a^{\langle t \rangle}} +dz^{\langle t \rangle} W_y \tag{15} da⟨t⟩=a⟨t⟩∂a⟨t+1⟩+∂a⟨t⟩∂l(t)=∂a⟨t+1⟩∂L(t+1)∂a⟨t⟩∂a⟨t+1⟩+dz⟨t⟩Wy=dz⟨t+1⟩Wy⋅∂a⟨t⟩∂a⟨t+1⟩+dz⟨t⟩Wy(15)
在这一步反向传播计算的难点在于 ∂ a ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ \frac{\partial a^{\langle t+1 \rangle}}{\partial a^{\langle t \rangle}} ∂a⟨t⟩∂a⟨t+1⟩。
因为 a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩受到上图这四部分所影响,而这四部分都和 a ⟨ t − 1 ⟩ a^{\langle t-1 \rangle} a⟨t−1⟩有关。所以 ∂ a ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ \frac{\partial a^{\langle t+1 \rangle}}{\partial a^{\langle t \rangle}} ∂a⟨t⟩∂a⟨t+1⟩的计算结果也由四部分组成(公式 ( 6 , 5 ) , ( 6 , 4 , 2 ) , ( 6 , 4 , 3 ) , ( 6 , 4 , 1 ) (6,5),(6,4,2),(6,4,3),(6,4,1) (6,5),(6,4,2),(6,4,3),(6,4,1)):
∂ a ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ = ∂ a ⟨ t + 1 ⟩ ∂ Γ o ⟨ t + 1 ⟩ ∂ Γ o ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ + ∂ a ⟨ t + 1 ⟩ ∂ c ⟨ t + 1 ⟩ ∂ c ⟨ t + 1 ⟩ ∂ c ~ ⟨ t + 1 ⟩ ∂ c ~ ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ + ∂ a ⟨ t + 1 ⟩ ∂ c ⟨ t + 1 ⟩ ∂ c ⟨ t + 1 ⟩ ∂ Γ i ⟨ t + 1 ⟩ ∂ Γ i ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ + ∂ a ⟨ t + 1 ⟩ ∂ c ⟨ t + 1 ⟩ ∂ c ⟨ t + 1 ⟩ ∂ Γ f ⟨ t + 1 ⟩ ∂ Γ f ⟨ t + 1 ⟩ ∂ a ⟨ t ⟩ = tanh ( c ⟨ t + 1 ⟩ ) ⋅ Γ o ⟨ t + 1 ⟩ ( 1 − Γ o ⟨ t + 1 ⟩ ) W o + Γ o ⟨ t + 1 ⟩ ( 1 − tanh ( c ⟨ t + 1 ⟩ ) 2 ) ⋅ Γ i ⟨ t + 1 ⟩ ⋅ ( 1 − c ~ ⟨ t ⟩ 2 ) W c + Γ o ⟨ t + 1 ⟩ ( 1 − tanh ( c ⟨ t + 1 ⟩ ) 2 ) ⋅ c ~ ⟨ t + 1 ⟩ ⋅ Γ i ⟨ t ⟩ ( 1 − Γ i ⟨ t ⟩ ) W i + Γ o ⟨ t + 1 ⟩ ( 1 − tanh ( c ⟨ t + 1 ⟩ ) 2 ) ⋅ c ⟨ t ⟩ ⋅ Γ f ⟨ t ⟩ ( 1 − Γ f ⟨ t ⟩ ) W f (16) \frac{\partial a^{\langle t+1 \rangle}}{\partial a^{\langle t \rangle}} = \frac{\partial a^{\langle t+1 \rangle}}{\partial \mathbf{\Gamma}_o^{\langle t+1 \rangle}} \frac{\partial \mathbf{\Gamma}_o^{\langle t+1 \rangle} }{\partial a^{\langle t \rangle} } + \frac{\partial a^{\langle t+1 \rangle}}{\partial c^{\langle t+1 \rangle}} \frac{\partial c^{\langle t+1 \rangle} }{\partial \mathbf{\tilde{c}}^{\langle t+1 \rangle} } \frac{\partial \mathbf{\tilde{c}}^{\langle t+1 \rangle} }{\partial a^{\langle t \rangle}} + \frac{\partial a^{\langle t+1 \rangle}}{\partial c^{\langle t+1 \rangle}} \frac{\partial c^{\langle t+1 \rangle} }{\partial \mathbf{\Gamma}_i^{\langle t+1 \rangle}} \frac{\partial \mathbf{\Gamma}_i^{\langle t+1 \rangle} }{\partial a^{\langle t \rangle}} + \frac{\partial a^{\langle t+1 \rangle}}{\partial c^{\langle t+1 \rangle}} \frac{\partial c^{\langle t+1 \rangle} }{\partial \mathbf{\Gamma}_f^{\langle t+1 \rangle}} \frac{\partial \mathbf{\Gamma}_f^{\langle t+1 \rangle} }{\partial a^{\langle t \rangle}} \\ = \tanh(c^{\langle t+1 \rangle}) \cdot \mathbf{\Gamma}_o^{\langle t+1 \rangle}(1-\mathbf{\Gamma}_o^{\langle t+1 \rangle})W_o + \mathbf{\Gamma}_o^{\langle t+1 \rangle}(1-{\tanh(c^{\langle t+1 \rangle})}^2)\cdot \mathbf{\Gamma}_i^{\langle t+1 \rangle}\cdot (1-{\mathbf{\tilde{c}}^{\langle t \rangle}}^2)W_c + \mathbf{\Gamma}_o^{\langle t+1 \rangle}(1-{\tanh(c^{\langle t+1 \rangle})}^2) \cdot \mathbf{\tilde{c}}^{\langle t+1 \rangle} \cdot \mathbf{\Gamma}_i^{\langle t \rangle}(1 - \mathbf{\Gamma}_i^{\langle t \rangle})W_i + \mathbf{\Gamma}_o^{\langle t+1 \rangle}(1-{\tanh(c^{\langle t+1 \rangle})}^2) \cdot c^{\langle t \rangle} \cdot \mathbf{\Gamma}_f^{\langle t \rangle}(1-\mathbf{\Gamma}_f^{\langle t \rangle})W_f \tag{16} ∂a⟨t⟩∂a⟨t+1⟩=∂Γo⟨t+1⟩∂a⟨t+1⟩∂a⟨t⟩∂Γo⟨t+1⟩+∂c⟨t+1⟩∂a⟨t+1⟩∂c~⟨t+1⟩∂c⟨t+1⟩∂a⟨t⟩∂c~⟨t+1⟩+∂c⟨t+1⟩∂a⟨t+1⟩∂Γi⟨t+1⟩∂c⟨t+1⟩∂a⟨t⟩∂Γi⟨t+1⟩+∂c⟨t+1⟩∂a⟨t+1⟩∂Γf⟨t+1⟩∂c⟨t+1⟩∂a⟨t⟩∂Γf⟨t+1⟩=tanh(c⟨t+1⟩)⋅Γo⟨t+1⟩(1−Γo⟨t+1⟩)Wo+Γo⟨t+1⟩(1−tanh(c⟨t+1⟩)2)⋅Γi⟨t+1⟩⋅(1−c~⟨t⟩2)Wc+Γo⟨t+1⟩(1−tanh(c⟨t+1⟩)2)⋅c~⟨t+1⟩⋅Γi⟨t⟩(1−Γi⟨t⟩)Wi+Γo⟨t+1⟩(1−tanh(c⟨t+1⟩)2)⋅c⟨t⟩⋅Γf⟨t⟩(1−Γf⟨t⟩)Wf(16)
上面有一个公共项 Γ o ⟨ t + 1 ⟩ ( 1 − tanh ( c ⟨ t + 1 ⟩ ) 2 ) \mathbf{\Gamma}_o^{\langle t+1 \rangle}(1-{\tanh(c^{\langle t+1 \rangle})}^2) Γo⟨t+1⟩(1−tanh(c⟨t+1⟩)2)
在时刻
t
(
t
<
T
x
)
t \,\,(t < T_x)
t(t<Tx)时,
c
⟨
t
⟩
c^{\langle t \rangle}
c⟨t⟩的梯度也是由当前时刻的误差以及
t
+
1
t+1
t+1时刻的误差组成(由公式
(
4
)
,
(
6
)
(4),(6)
(4),(6))得:
d
c
⟨
t
⟩
=
∂
L
∂
c
⟨
t
+
1
⟩
∂
c
⟨
t
+
1
⟩
∂
c
⟨
t
⟩
+
∂
L
∂
a
⟨
t
⟩
∂
a
⟨
t
⟩
∂
c
⟨
t
⟩
=
d
c
⟨
t
+
1
⟩
∂
c
⟨
t
+
1
⟩
∂
c
⟨
t
⟩
+
d
a
⟨
t
⟩
Γ
o
⟨
t
⟩
(
1
−
tanh
(
c
⟨
t
⟩
)
2
)
=
d
c
⟨
t
+
1
⟩
Γ
f
⟨
t
+
1
⟩
+
d
a
⟨
t
⟩
Γ
o
⟨
t
⟩
(
1
−
tanh
(
c
⟨
t
⟩
)
2
)
(17)
dc^{\langle t \rangle} = \frac{\partial L}{\partial c^{\langle t+1 \rangle}} \frac{\partial c^{\langle t+1 \rangle}}{\partial c^{\langle t \rangle}} + \frac{\partial L}{\partial a^{\langle t \rangle}} \frac{\partial a^{\langle t \rangle}}{\partial c^{\langle t \rangle}} \\ = dc^{\langle t+1 \rangle}\frac{\partial c^{\langle t+1 \rangle}}{\partial c^{\langle t \rangle}} + da^{\langle t \rangle}\mathbf{\Gamma}_o^{\langle t \rangle}(1-{\tanh(c^{\langle t \rangle})}^2) \\ = dc^{\langle t+1 \rangle}\mathbf{\Gamma}_f^{\langle t+1 \rangle} + da^{\langle t \rangle}\mathbf{\Gamma}_o^{\langle t \rangle}(1-{\tanh(c^{\langle t \rangle})}^2) \tag{17}
dc⟨t⟩=∂c⟨t+1⟩∂L∂c⟨t⟩∂c⟨t+1⟩+∂a⟨t⟩∂L∂c⟨t⟩∂a⟨t⟩=dc⟨t+1⟩∂c⟨t⟩∂c⟨t+1⟩+da⟨t⟩Γo⟨t⟩(1−tanh(c⟨t⟩)2)=dc⟨t+1⟩Γf⟨t+1⟩+da⟨t⟩Γo⟨t⟩(1−tanh(c⟨t⟩)2)(17)
现在求对
W
o
,
W
f
,
W
i
,
W
c
W_o,W_f,W_i,W_c
Wo,Wf,Wi,Wc的梯度就简单了。
d
W
o
=
∂
L
∂
a
⟨
t
⟩
⋅
∂
a
⟨
t
⟩
∂
Γ
o
⟨
t
⟩
⋅
∂
Γ
o
⟨
t
⟩
W
o
=
d
a
⟨
t
⟩
⋅
t
a
n
h
(
c
⟨
t
⟩
)
⋅
Γ
o
⟨
t
⟩
(
1
−
Γ
o
⟨
t
⟩
)
[
a
p
r
e
v
x
t
]
T
(18)
dW_o =\frac{\partial L}{\partial a^{\langle t \rangle}} \cdot \frac{\partial a^{\langle t \rangle}}{\partial \mathbf{\Gamma}_o^{\langle t \rangle}} \cdot \frac{\partial \mathbf{\Gamma}_o^{\langle t \rangle}}{W_o} = da^{\langle t \rangle} \cdot tanh(c^{\langle t \rangle}) \cdot \Gamma_o^{\langle t \rangle}(1-\Gamma_o^{\langle t \rangle}) \begin{bmatrix} a_{prev} \\ x_t\end{bmatrix}^T \tag{18}
dWo=∂a⟨t⟩∂L⋅∂Γo⟨t⟩∂a⟨t⟩⋅Wo∂Γo⟨t⟩=da⟨t⟩⋅tanh(c⟨t⟩)⋅Γo⟨t⟩(1−Γo⟨t⟩)[aprevxt]T(18)
d
b
o
=
d
a
⟨
t
⟩
⋅
∂
a
⟨
t
⟩
∂
Γ
o
⟨
t
⟩
⋅
∂
Γ
o
⟨
t
⟩
b
o
=
d
a
⟨
t
⟩
⋅
t
a
n
h
(
c
⟨
t
⟩
)
⋅
Γ
o
⟨
t
⟩
(
1
−
Γ
o
⟨
t
⟩
)
(19)
db_o = da^{\langle t \rangle} \cdot \frac{\partial a^{\langle t \rangle}}{\partial \mathbf{\Gamma}_o^{\langle t \rangle}} \cdot \frac{\partial \mathbf{\Gamma}_o^{\langle t \rangle}}{b_o} = da^{\langle t \rangle} \cdot tanh(c^{\langle t \rangle}) \cdot \Gamma_o^{\langle t \rangle}(1-\Gamma_o^{\langle t \rangle}) \tag{19}
dbo=da⟨t⟩⋅∂Γo⟨t⟩∂a⟨t⟩⋅bo∂Γo⟨t⟩=da⟨t⟩⋅tanh(c⟨t⟩)⋅Γo⟨t⟩(1−Γo⟨t⟩)(19)
d W f = ∂ L ∂ c ⟨ t ⟩ ⋅ ∂ c ⟨ t ⟩ ∂ Γ f ⟨ t ⟩ ⋅ ∂ Γ f ⟨ t ⟩ ∂ W f = d c ⟨ t ⟩ ⋅ c ⟨ t − 1 ⟩ ⋅ Γ f ⟨ t ⟩ ( 1 − Γ f ⟨ t ⟩ ) [ a p r e v x t ] T (20) dW_f = \frac{\partial L}{\partial c^{\langle t \rangle}} \cdot \frac{\partial c^{\langle t \rangle}}{\partial \Gamma_f^{\langle t \rangle}} \cdot \frac{\partial \Gamma_f^{\langle t \rangle}}{\partial W_f} \\ = dc^{\langle t \rangle} \cdot c^{\langle t-1 \rangle} \cdot \Gamma_f^{\langle t \rangle}(1-\Gamma_f^{\langle t \rangle})\begin{bmatrix} a_{prev} \\ x_t\end{bmatrix}^T \tag{20} dWf=∂c⟨t⟩∂L⋅∂Γf⟨t⟩∂c⟨t⟩⋅∂Wf∂Γf⟨t⟩=dc⟨t⟩⋅c⟨t−1⟩⋅Γf⟨t⟩(1−Γf⟨t⟩)[aprevxt]T(20)
d b f = d c ⟨ t ⟩ ⋅ c ⟨ t − 1 ⟩ ⋅ Γ f ⟨ t ⟩ ( 1 − Γ f ⟨ t ⟩ ) (21) db_f = dc^{\langle t \rangle} \cdot c^{\langle t-1 \rangle} \cdot \Gamma_f^{\langle t \rangle}(1-\Gamma_f^{\langle t \rangle}) \tag{21} dbf=dc⟨t⟩⋅c⟨t−1⟩⋅Γf⟨t⟩(1−Γf⟨t⟩)(21)
d
W
i
=
∂
L
∂
c
⟨
t
⟩
⋅
∂
c
⟨
t
⟩
∂
Γ
i
⟨
t
⟩
⋅
∂
Γ
i
⟨
t
⟩
∂
W
i
=
d
c
⟨
t
⟩
⋅
c
~
⟨
t
⟩
⋅
Γ
i
⟨
t
⟩
(
1
−
Γ
i
⟨
t
⟩
)
[
a
p
r
e
v
x
t
]
T
(22)
dW_i = \frac{\partial L}{\partial c^{\langle t \rangle}} \cdot \frac{\partial c^{\langle t \rangle}}{\partial \Gamma_i^{\langle t \rangle}} \cdot \frac{\partial \Gamma_i^{\langle t \rangle}}{\partial W_i} \\ = dc^{\langle t \rangle} \cdot \mathbf{\tilde{c}}^{\langle t \rangle} \cdot \Gamma_i^{\langle t \rangle}(1-\Gamma_i^{\langle t \rangle})\begin{bmatrix} a_{prev} \\ x_t\end{bmatrix}^T \tag{22}
dWi=∂c⟨t⟩∂L⋅∂Γi⟨t⟩∂c⟨t⟩⋅∂Wi∂Γi⟨t⟩=dc⟨t⟩⋅c~⟨t⟩⋅Γi⟨t⟩(1−Γi⟨t⟩)[aprevxt]T(22)
d
b
i
=
d
c
⟨
t
⟩
⋅
c
~
⟨
t
⟩
⋅
Γ
i
⟨
t
⟩
(
1
−
Γ
i
⟨
t
⟩
)
(23)
db_i = dc^{\langle t \rangle} \cdot \mathbf{\tilde{c}}^{\langle t \rangle} \cdot \Gamma_i^{\langle t \rangle}(1-\Gamma_i^{\langle t \rangle}) \tag{23}
dbi=dc⟨t⟩⋅c~⟨t⟩⋅Γi⟨t⟩(1−Γi⟨t⟩)(23)
d W c = ∂ L ∂ c ⟨ t ⟩ ⋅ ∂ c ⟨ t ⟩ ∂ c ~ ⟨ t ⟩ ⋅ ∂ c ~ ⟨ t ⟩ ∂ W c = d c ⟨ t ⟩ ⋅ Γ i ⟨ t ⟩ ⋅ ( 1 − c ~ ⟨ t ⟩ 2 ) [ a p r e v x t ] T (24) dW_c = \frac{\partial L}{\partial c^{\langle t \rangle}} \cdot \frac{\partial c^{\langle t \rangle}}{\partial \mathbf{\tilde{c}}^{\langle t \rangle}} \cdot \frac{\partial \mathbf{\tilde{c}}^{\langle t \rangle}}{\partial W_c} \\ = dc^{\langle t \rangle} \cdot \Gamma_i^{\langle t \rangle}\cdot (1-{\mathbf{\tilde{c}}^{\langle t \rangle}}^2) \begin{bmatrix} a_{prev} \\ x_t\end{bmatrix}^T \tag{24} dWc=∂c⟨t⟩∂L⋅∂c~⟨t⟩∂c⟨t⟩⋅∂Wc∂c~⟨t⟩=dc⟨t⟩⋅Γi⟨t⟩⋅(1−c~⟨t⟩2)[aprevxt]T(24)
d b c = d c ⟨ t ⟩ ⋅ Γ i ⟨ t ⟩ ⋅ ( 1 − c ~ ⟨ t ⟩ 2 ) (25) db_c = dc^{\langle t \rangle} \cdot \Gamma_i^{\langle t \rangle}\cdot (1-{\mathbf{\tilde{c}}^{\langle t \rangle}}^2) \tag{25} dbc=dc⟨t⟩⋅Γi⟨t⟩⋅(1−c~⟨t⟩2)(25)