Backpropagation Through Time

This article is based on the book *Dive into Deep Learning* and gives a relatively detailed derivation for the corresponding chapter.

I. Backpropagation derivation for RNNs

1. Problem statement

The RNN's relations at time step $t$ are:
$$\left \{ \begin{array}{ll} h_t = W_{hx}x_t + W_{hh}h_{t-1} \\ O_t = W_{qh}h_t \end{array} \right .$$
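As a quick sanity check on the notation, the recurrence can be sketched in NumPy. This is a minimal sketch: the helper name `rnn_forward`, the sizes, and the absence of an activation function all follow the simplified equations in this text, not a production RNN.

```python
import numpy as np

def rnn_forward(W_hx, W_hh, W_qh, xs, h0):
    """Unroll the linear RNN: h_t = W_hx x_t + W_hh h_{t-1}, O_t = W_qh h_t."""
    hs, outputs = [h0], []
    for x in xs:
        h = W_hx @ x + W_hh @ hs[-1]  # hidden-state recurrence
        hs.append(h)
        outputs.append(W_qh @ h)      # per-step output O_t
    return hs, outputs

# tiny example: hidden size 3, input size 2, output size 2, T = 4 steps
rng = np.random.default_rng(0)
W_hx = rng.normal(size=(3, 2))
W_hh = rng.normal(size=(3, 3))
W_qh = rng.normal(size=(2, 3))
xs = [rng.normal(size=2) for _ in range(4)]
hs, outputs = rnn_forward(W_hx, W_hh, W_qh, xs, np.zeros(3))
```

The returned `hs` keeps the initial state $h_0$ in front, so `hs[t]` is $h_{t-1}$; this off-by-one convention is convenient later when the gradient of $W_{hh}$ pairs $\frac{\partial L}{\partial h_t}$ with $h_{t-1}$.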
Suppose we have the loss function
$$L = \frac{1}{T}\sum_{t=1}^{T}l(O_t, y_t)$$
We want to compute
$$\frac{\partial L}{\partial W_{qh}}, \quad \frac{\partial L}{\partial W_{hx}}, \quad \frac{\partial L}{\partial W_{hh}}$$
Some preparation: the chain rule for matrix derivatives and the basic rules and principles of differentiation are assumed to be familiar.
2. Solution

First, we solve for
$$\frac{\partial L}{\partial W_{qh}}$$
For any time step $t$, we clearly have:
$$\frac{\partial L}{\partial O_t} = \frac{1}{T} \cdot \frac{\partial l(O_t, y_t)}{\partial O_t}$$
$$\mathrm{d}l = \mathrm{tr}\left( {\left( \frac{\partial l}{\partial O_t} \right)}^T \cdot \mathrm{d}O_t \right)$$
$$O_t = W_{qh}h_t$$
Substituting $O_t$ into the differential identity and summing over all time steps (with the $\frac{1}{T}$ factor from $L$ absorbed into $\frac{\partial L}{\partial O_t}$), we have:
$$\mathrm{d}L = \mathrm{tr}\left( \sum_{t=1}^{T}{\left( \frac{\partial L}{\partial O_t} \right)}^T \mathrm{d}W_{qh} \cdot h_t \right)$$
Using the cyclic property of the trace to move $h_t$ to the front, we have:
$$\mathrm{d}L = \mathrm{tr}\left( \sum_{t=1}^{T}h_t{\left( \frac{\partial L}{\partial O_t} \right)}^T \mathrm{d}W_{qh} \right)$$
Therefore:
$$\frac{\partial L}{\partial W_{qh}} = \left( \sum_{t=1}^{T}h_t{\left( \frac{\partial L}{\partial O_t} \right)}^T \right)^T = \sum_{t=1}^{T} \frac{\partial L}{\partial O_t} \, h_t^T$$
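This closed form can be checked numerically. Below is a minimal sketch: the squared-error choice $l(O_t, y_t) = \lVert O_t - y_t \rVert^2$ and all sizes are assumptions made for illustration only; it compares $\sum_t \frac{\partial L}{\partial O_t} h_t^T$ against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
n_h, n_x, n_q, T = 4, 3, 2, 5
W_hx = rng.normal(scale=0.1, size=(n_h, n_x))
W_hh = rng.normal(scale=0.1, size=(n_h, n_h))
W_qh = rng.normal(scale=0.1, size=(n_q, n_h))
xs = [rng.normal(size=n_x) for _ in range(T)]
ys = [rng.normal(size=n_q) for _ in range(T)]
h0 = np.zeros(n_h)

def loss(W):
    """L = (1/T) * sum_t ||W h_t - y_t||^2 with the linear recurrence for h_t."""
    h, total = h0, 0.0
    for x, y in zip(xs, ys):
        h = W_hx @ x + W_hh @ h
        total += np.sum((W @ h - y) ** 2)
    return total / T

# analytic gradient: sum_t (dL/dO_t) h_t^T, where dL/dO_t = (2/T)(O_t - y_t)
h, grad = h0, np.zeros_like(W_qh)
for x, y in zip(xs, ys):
    h = W_hx @ x + W_hh @ h
    grad += np.outer((2.0 / T) * (W_qh @ h - y), h)

# central finite differences over every entry of W_qh
num, eps = np.zeros_like(W_qh), 1e-6
for i in range(n_q):
    for j in range(n_h):
        Wp, Wm = W_qh.copy(), W_qh.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        num[i, j] = (loss(Wp) - loss(Wm)) / (2 * eps)
```

Here `grad` and `num` should agree to within finite-difference precision.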
Next, we solve for
$$\frac{\partial L}{\partial W_{hx}},\quad \frac{\partial L}{\partial W_{hh}}$$
We start from time step $T$ (here $\mathrm{prod}(\cdot)$ denotes the matrix chain-rule product). Recall:
$$\left \{ \begin{array}{ll} h_t = W_{hx}x_t + W_{hh}h_{t-1} \\ O_t = W_{qh}h_t \end{array} \right .$$
$$\frac{\partial L}{\partial h_T} = \mathrm{prod}\left( \frac{\partial L}{\partial O_T}, \frac{\partial O_T}{\partial h_T} \right)$$
For time step $T-1$, we have
$$\frac{\partial L}{\partial h_{T-1}} = \mathrm{prod}\left( \frac{\partial L}{\partial O_{T-1}}, \frac{\partial O_{T-1}}{\partial h_{T-1}} \right) + \mathrm{prod}\left( \frac{\partial L}{\partial h_T}, \frac{\partial h_T}{\partial h_{T-1}} \right)$$
…
Similarly, for any time step $t < T$:
$$\frac{\partial L}{\partial h_t} = \mathrm{prod}\left( \frac{\partial L}{\partial O_t}, \frac{\partial O_t}{\partial h_t} \right) + \mathrm{prod}\left( \frac{\partial L}{\partial h_{t+1}}, \frac{\partial h_{t+1}}{\partial h_t} \right)$$
Evaluating these partial derivatives with the same matrix-trace chain-rule technique used above for $\frac{\partial L}{\partial W_{qh}}$, we obtain:
$$\frac{\partial L}{\partial h_t} = W_{qh}^T \frac{\partial L}{\partial O_t} + W_{hh}^T \frac{\partial L}{\partial h_{t+1}}$$
Unrolling this recursion gives:
$$\frac{\partial L}{\partial h_t} = \sum_{i=t}^T \left( W_{hh}^T \right)^{T-i} W_{qh}^T \frac{\partial L}{\partial O_{T+t-i}}$$
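That the unrolled sum really matches the recursion can be verified with random matrices. This is a sketch: the $\frac{\partial L}{\partial O_t}$ vectors below are arbitrary placeholders, which is enough because the identity is purely algebraic.

```python
import numpy as np

rng = np.random.default_rng(1)
n_h, n_q, T = 4, 2, 6
W_hh = rng.normal(size=(n_h, n_h))
W_qh = rng.normal(size=(n_q, n_h))
dL_dO = [rng.normal(size=n_q) for _ in range(T)]  # stand-ins for dL/dO_1 .. dL/dO_T

# backward recursion: dL/dh_t = W_qh^T dL/dO_t + W_hh^T dL/dh_{t+1}
dL_dh = [None] * T
dL_dh[-1] = W_qh.T @ dL_dO[-1]
for t in range(T - 2, -1, -1):
    dL_dh[t] = W_qh.T @ dL_dO[t] + W_hh.T @ dL_dh[t + 1]

def closed_form(t):
    """dL/dh_t = sum_{i=t}^{T} (W_hh^T)^{T-i} W_qh^T dL/dO_{T+t-i}  (t is 1-based)."""
    acc = np.zeros(n_h)
    for i in range(t, T + 1):
        acc += np.linalg.matrix_power(W_hh.T, T - i) @ W_qh.T @ dL_dO[T + t - i - 1]
    return acc
```

For every $t$, `closed_form(t)` reproduces the recursively computed `dL_dh[t - 1]`. The $\left( W_{hh}^T \right)^{T-i}$ powers in the closed form also make the vanishing/exploding-gradient behavior of long sequences visible: repeated multiplication by $W_{hh}^T$ shrinks or blows up the contribution of distant time steps.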
Therefore,
$$\frac{\partial L}{\partial W_{hx}} = \mathrm{prod}\left( \frac{\partial L}{\partial h_t}, \frac{\partial h_t}{\partial W_{hx}} \right)$$
$$\frac{\partial L}{\partial W_{hh}} = \mathrm{prod}\left( \frac{\partial L}{\partial h_t}, \frac{\partial h_t}{\partial W_{hh}} \right)$$
It then follows that (the prod chain rule here is applied exactly as above; the details are left as an exercise):
$$\frac{\partial L}{\partial W_{hx}} = \sum_{t=1}^T\frac{\partial L}{\partial h_t}x_t^T$$
$$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T\frac{\partial L}{\partial h_t}h_{t-1}^T$$
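These two sums can again be checked against finite differences with a full backward pass. A minimal sketch follows, under the same illustrative assumptions as before (squared-error loss, arbitrary small sizes):

```python
import numpy as np

rng = np.random.default_rng(2)
n_h, n_x, n_q, T = 3, 2, 2, 4
W_hx = rng.normal(scale=0.1, size=(n_h, n_x))
W_hh = rng.normal(scale=0.1, size=(n_h, n_h))
W_qh = rng.normal(scale=0.1, size=(n_q, n_h))
xs = [rng.normal(size=n_x) for _ in range(T)]
ys = [rng.normal(size=n_q) for _ in range(T)]
h0 = np.zeros(n_h)

def loss(A, B):
    """L = (1/T) sum_t ||W_qh h_t - y_t||^2 with h_t = A x_t + B h_{t-1}."""
    h, total = h0, 0.0
    for x, y in zip(xs, ys):
        h = A @ x + B @ h
        total += np.sum((W_qh @ h - y) ** 2)
    return total / T

# forward pass, keeping every hidden state (hs[t] is h_{t-1} in 0-based indexing)
hs = [h0]
for x in xs:
    hs.append(W_hx @ x + W_hh @ hs[-1])

# backward pass: dL/dh_t = W_qh^T dL/dO_t + W_hh^T dL/dh_{t+1}
g_hx, g_hh = np.zeros_like(W_hx), np.zeros_like(W_hh)
dL_dh_next = np.zeros(n_h)
for t in range(T - 1, -1, -1):
    dL_dO = (2.0 / T) * (W_qh @ hs[t + 1] - ys[t])
    dL_dh = W_qh.T @ dL_dO + W_hh.T @ dL_dh_next
    g_hx += np.outer(dL_dh, xs[t])   # term dL/dh_t x_t^T
    g_hh += np.outer(dL_dh, hs[t])   # term dL/dh_t h_{t-1}^T
    dL_dh_next = dL_dh

def numgrad(f, W, eps=1e-6):
    g = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        g[idx] = (f(Wp) - f(Wm)) / (2 * eps)
    return g

num_hx = numgrad(lambda A: loss(A, W_hh), W_hx)
num_hh = numgrad(lambda B: loss(W_hx, B), W_hh)
```

Both analytic gradients (`g_hx`, `g_hh`) should match their finite-difference counterparts, which is exactly what BPTT computes in one backward sweep instead of one loss evaluation per weight entry.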
Together with the result obtained earlier:
$$\frac{\partial L}{\partial W_{qh}} = \sum_{t=1}^{T} \frac{\partial L}{\partial O_t} \, h_t^T$$
This completes the backpropagation derivation for the RNN.

II. Backpropagation derivation for LSTMs

1. Problem statement
$$\begin{array}{ll} I_t=\sigma\left( W_{xi}X_t + W_{hi}H_{t-1} + b_i \right) \\ F_t=\sigma\left( W_{xf}X_t + W_{hf}H_{t-1} + b_f \right) \\ O_t=\sigma\left( W_{xo}X_t + W_{ho}H_{t-1} + b_o \right) \\ C_t^{'}=\mathrm{tanh}\left( W_{xc}X_t + W_{hc}H_{t-1} + b_{c} \right) \\ C_t=F_t \odot C_{t-1} + I_t \odot C_t^{'} \\ H_t=O_t \odot \mathrm{tanh}(C_t) \\ Y_t=W_{qh}H_t + b_q \end{array}$$
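Before deriving any gradients, the forward equations can be sketched in NumPy. This is a minimal sketch: the parameter-dict layout and the `lstm_step` name are illustrative assumptions, $\sigma$ is the logistic sigmoid, and $\odot$ is elementwise multiplication.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(p, x_t, h_prev, c_prev):
    """One LSTM time step; p maps names like 'W_xi', 'b_i' to weight arrays."""
    i = sigmoid(p["W_xi"] @ x_t + p["W_hi"] @ h_prev + p["b_i"])       # input gate  I_t
    f = sigmoid(p["W_xf"] @ x_t + p["W_hf"] @ h_prev + p["b_f"])       # forget gate F_t
    o = sigmoid(p["W_xo"] @ x_t + p["W_ho"] @ h_prev + p["b_o"])       # output gate O_t
    c_cand = np.tanh(p["W_xc"] @ x_t + p["W_hc"] @ h_prev + p["b_c"])  # candidate   C_t'
    c = f * c_prev + i * c_cand   # cell state   C_t (elementwise gating)
    h = o * np.tanh(c)            # hidden state H_t
    y = p["W_qh"] @ h + p["b_q"]  # output       Y_t
    return h, c, y

# tiny example: hidden size 3, input size 2, output size 2
rng = np.random.default_rng(0)
n_h, n_x, n_q = 3, 2, 2
p = {f"W_x{g}": rng.normal(size=(n_h, n_x)) for g in "ifoc"}
p.update({f"W_h{g}": rng.normal(size=(n_h, n_h)) for g in "ifoc"})
p.update({f"b_{g}": np.zeros(n_h) for g in "ifoc"})
p["W_qh"], p["b_q"] = rng.normal(size=(n_q, n_h)), np.zeros(n_q)
h, c, y = lstm_step(p, rng.normal(size=n_x), np.zeros(n_h), np.zeros(n_h))
```

Note how $C_t$ is an additive blend of the previous cell state and the candidate, gated elementwise; this additive path is what the backward derivation will exploit.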