RNN Vanishing Gradient
RNN Forward Propagation:
$$\mathbf{h}^{(t)} = \sigma\left(\mathbf{W}_h \mathbf{h}^{(t-1)} + \mathbf{W}_x \mathbf{x}^{(t)} + \mathbf{b}_1\right)$$

$$\mathbf{\hat{y}}^{(t)} = \mathrm{softmax}\left(\mathbf{W}_s \mathbf{h}^{(t)} + \mathbf{b}_2\right)$$
where the input $\mathbf{x}^{(t)} \in \mathbb{R}^d$ is a $d$-dimensional vector, the hidden state $\mathbf{h}^{(t)} \in \mathbb{R}^{D_h}$ is a $D_h$-dimensional vector, and $\mathbf{\hat{y}}^{(t)} \in \mathbb{R}^{|V|}$ holds the predicted probability of each word. The parameters are $\mathbf{W}_x \in \mathbb{R}^{D_h \times d}$, $\mathbf{W}_h \in \mathbb{R}^{D_h \times D_h}$, and $\mathbf{W}_s \in \mathbb{R}^{|V| \times D_h}$, where $|V|$ is the vocabulary size and $\sigma$ is the sigmoid function.
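To make the shapes concrete, here is a minimal NumPy sketch of one forward step. The dimension values and random parameters are arbitrary choices for illustration, not anything prescribed by the model:

```python
import numpy as np

rng = np.random.default_rng(0)

d, D_h, V = 50, 100, 1000            # input dim, hidden dim, vocabulary size

# Parameters, matching the shapes above
W_x = rng.normal(scale=0.1, size=(D_h, d))
W_h = rng.normal(scale=0.1, size=(D_h, D_h))
W_s = rng.normal(scale=0.1, size=(V, D_h))
b_1 = np.zeros(D_h)
b_2 = np.zeros(V)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())          # shift for numerical stability
    return e / e.sum()

def rnn_step(h_prev, x_t):
    """One forward step: returns h^(t) and the prediction y_hat^(t)."""
    h_t = sigmoid(W_h @ h_prev + W_x @ x_t + b_1)
    y_hat_t = softmax(W_s @ h_t + b_2)
    return h_t, y_hat_t

h = np.zeros(D_h)                    # h^(0)
x = rng.normal(size=d)               # a toy input x^(1)
h, y_hat = rnn_step(h, x)
print(h.shape, y_hat.shape, y_hat.sum())   # (100,) (1000,) 1.0
```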
The loss at time step $t$ takes the cross-entropy form, where $\mathbf{y}^{(t)} \in \mathbb{R}^{|V|}$ is the true probability distribution, usually one-hot: probability 1 at the correct word's position and 0 everywhere else:
$$J^{(t)}(\theta) = -\sum_{w=1}^{|V|} y^{(t)}_w \log \hat{y}^{(t)}_w$$
The loss over the whole corpus, where the corpus size is $T$, is then:
$$J = \frac{1}{T} \sum_{t=1}^T J^{(t)}(\theta) = -\frac{1}{T} \sum_{t=1}^T \sum_{w=1}^{|V|} y^{(t)}_w \log \hat{y}^{(t)}_w$$
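Continuing the sketch above, both losses are a few lines of code. With a one-hot $\mathbf{y}^{(t)}$ the inner sum collapses to $-\log \hat{y}^{(t)}_{w^*}$ for the correct word $w^*$; the toy corpus here is random and purely illustrative:

```python
def cross_entropy(y_hat_t, target_idx):
    """J^(t): with one-hot y^(t), only the correct word's log-prob survives."""
    return -np.log(y_hat_t[target_idx])

# Toy corpus: T random input vectors with random target word indices
T = 20
xs = rng.normal(size=(T, d))
targets = rng.integers(0, V, size=T)

h = np.zeros(D_h)
J = 0.0
for t in range(T):
    h, y_hat = rnn_step(h, xs[t])
    J += cross_entropy(y_hat, targets[t])
J /= T                               # average over the corpus
print(J)                             # roughly log(V) for random parameters
```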
RNN Backward Propagation
The model has three weight matrices. Taking the gradient with respect to $\mathbf{W}_h \in \mathbb{R}^{D_h \times D_h}$ as an example, the multivariable chain rule gives:
$$\frac{\partial J}{\partial \mathbf{W}_h} = \sum_{t=1}^T \frac{\partial J^{(t)}}{\partial \mathbf{W}_h} = \sum_{t=1}^T \sum_{k=1}^t \frac{\partial J^{(t)}}{\partial \mathbf{\hat{y}}^{(t)}} \frac{\partial \mathbf{\hat{y}}^{(t)}}{\partial \mathbf{h}^{(t)}} \frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(k)}} \frac{\partial \mathbf{h}^{(k)}}{\partial \mathbf{W}_h}$$
where $\frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(k)}}$ expands into a product of per-step Jacobians:

$$\frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(k)}} = \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}^{(j)}}{\partial \mathbf{h}^{(j-1)}} = \prod_{j=k+1}^{t} \mathbf{W}_h \, \mathrm{diag}\!\left(\sigma'(\mathbf{h}^{(j-1)})\right)$$
Substituting this back:
$$\begin{aligned}
\frac{\partial J}{\partial \mathbf{W}_h} &= \sum_{t=1}^T \frac{\partial J^{(t)}}{\partial \mathbf{W}_h} \\
&= \sum_{t=1}^T \sum_{k=1}^t \frac{\partial J^{(t)}}{\partial \mathbf{\hat{y}}^{(t)}} \frac{\partial \mathbf{\hat{y}}^{(t)}}{\partial \mathbf{h}^{(t)}} \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}^{(j)}}{\partial \mathbf{h}^{(j-1)}} \frac{\partial \mathbf{h}^{(k)}}{\partial \mathbf{W}_h} \\
&= \sum_{t=1}^T \sum_{k=1}^t \frac{\partial J^{(t)}}{\partial \mathbf{\hat{y}}^{(t)}} \frac{\partial \mathbf{\hat{y}}^{(t)}}{\partial \mathbf{h}^{(t)}} \prod_{j=k+1}^{t} \mathbf{W}_h \, \mathrm{diag}\!\left(\sigma'(\mathbf{h}^{(j-1)})\right) \frac{\partial \mathbf{h}^{(k)}}{\partial \mathbf{W}_h} \\
&= \sum_{t=1}^T \sum_{k=1}^t \frac{\partial J^{(t)}}{\partial \mathbf{\hat{y}}^{(t)}} \frac{\partial \mathbf{\hat{y}}^{(t)}}{\partial \mathbf{h}^{(t)}} \mathbf{W}_h^{t-k} \prod_{j=k+1}^{t} \mathrm{diag}\!\left(\sigma'(\mathbf{h}^{(j-1)})\right) \frac{\partial \mathbf{h}^{(k)}}{\partial \mathbf{W}_h}
\end{aligned}$$
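Before analyzing this term, the behaviour of the repeated factor $\mathbf{W}_h \, \mathrm{diag}(\sigma'(\cdot))$ can be checked numerically. A sketch continuing the NumPy example above (random toy inputs, so the exact numbers are illustrative only) that multiplies the per-step Jacobian factors together and watches the norm of the product:

```python
# Re-run the forward pass, keeping the hidden states h^(1) ... h^(T)
h = np.zeros(D_h)
hs = []
for t in range(T):
    h, _ = rnn_step(h, xs[t])
    hs.append(h)

# Accumulate the product of per-step factors W_h @ diag(sigma'(z^(j))).
# Since h^(j) = sigmoid(z^(j)), the derivative is h^(j) * (1 - h^(j)).
prod = np.eye(D_h)
for j in range(1, T):
    prod = (W_h @ np.diag(hs[j] * (1.0 - hs[j]))) @ prod
    if j % 5 == 0:
        norm = np.linalg.norm(prod, 2)   # spectral norm of the product
        print(f"t-k = {j:2d}   norm = {norm:.3e}")
```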
Notice the exponential term $\mathbf{W}_h^{t-k}$ (the final step above informally treats the factors as commuting so that the copies of $\mathbf{W}_h$ can be collected into a single power). If $\mathbf{W}_h$ is small, then as the distance between $t$ and $k$ grows (i.e. as $t-k$ gets large), the partial derivative shrinks exponentially and becomes vanishingly small. More precisely, the behaviour is governed by the eigenvalues of the matrix. Suppose $\mathbf{W}_h$ has the eigendecomposition $\mathbf{W}_h = A \Lambda A^{-1}$, where the eigenvalues on the diagonal of $\Lambda$ are sorted by absolute value ($|\lambda_1| \ge |\lambda_2| \ge \dots \ge |\lambda_n|$). Then, writing $\mathbf{a}_i$ for the (assumed orthonormal) eigenvectors,

$$\mathbf{W}_h^{t-k} = (A\Lambda A^{-1})\cdots(A\Lambda A^{-1}) = A\Lambda^{t-k}A^{-1} = \lambda_1^{t-k}\mathbf{a}_1\mathbf{a}_1^T + \lambda_2^{t-k}\mathbf{a}_2\mathbf{a}_2^T + \dots + \lambda_n^{t-k}\mathbf{a}_n\mathbf{a}_n^T$$

so if $|\lambda_1| < 1$, then for large $t-k$ the matrix $\mathbf{W}_h^{t-k}$ is nearly zero.
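The eigenvalue effect is easy to observe directly. A self-contained sketch (the values $0.9$ and $1.1$ are arbitrary illustrations of the two regimes) that rescales a random matrix to a chosen largest eigenvalue magnitude and raises it to growing powers:

```python
import numpy as np

rng = np.random.default_rng(1)

def power_norms(lam_max, n=50, powers=(1, 10, 30, 60)):
    """Spectral norms of W^p for a random W rescaled so that its
    largest eigenvalue magnitude equals lam_max."""
    W = rng.normal(size=(n, n))
    W *= lam_max / np.max(np.abs(np.linalg.eigvals(W)))
    return [np.linalg.norm(np.linalg.matrix_power(W, p), 2) for p in powers]

print(power_norms(0.9))   # |lambda_1| < 1: norms shrink toward 0
print(power_norms(1.1))   # |lambda_1| > 1: norms blow up
```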
Let $\beta_W$ and $\beta_h$ denote upper bounds on the norms of the two matrices, i.e. for any $\mathbf{W}$ and $\mathrm{diag}(\sigma'(\mathbf{h}))$ we have $\|\mathbf{W}\| \le \beta_W$ and $\|\mathrm{diag}(\sigma'(\mathbf{h}))\| \le \beta_h$. Using the matrix L2 norm, which is submultiplicative,
$$\left\| \frac{\partial \mathbf{h}^{(j)}}{\partial \mathbf{h}^{(j-1)}} \right\| = \left\| \mathbf{W}_h \, \mathrm{diag}\!\left(\sigma'(\mathbf{h}^{(j-1)})\right) \right\| \le \left\| \mathbf{W}_h \right\| \left\| \mathrm{diag}\!\left(\sigma'(\mathbf{h}^{(j-1)})\right) \right\| \le \beta_W \beta_h$$
Therefore,
$$\left\| \frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(k)}} \right\| = \left\| \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}^{(j)}}{\partial \mathbf{h}^{(j-1)}} \right\| \le \prod_{j=k+1}^{t} \left\| \frac{\partial \mathbf{h}^{(j)}}{\partial \mathbf{h}^{(j-1)}} \right\| \le (\beta_W \beta_h)^{t-k}$$
Because $(\beta_W \beta_h)^{t-k}$ is an exponential term, it easily becomes either very large or very small once $t-k$ is big. The RNN here uses the sigmoid nonlinearity, whose derivative satisfies $\sigma'(x) = \sigma(x)(1 - \sigma(x)) \le 1/4$, so $\|\mathrm{diag}(\sigma'(\mathbf{h}))\| \le 1/4$. Whenever $\beta_W \beta_h < 1$, the bound $(\beta_W \beta_h)^{t-k}$ becomes vanishingly small for large $t-k$: this is the vanishing gradient problem.
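The whole argument can be checked end to end: run the forward pass, backpropagate the loss of the last time step, and record $\|\partial J^{(T)}/\partial \mathbf{h}^{(k)}\|$ as $k$ moves backward in time. A minimal sketch reusing the definitions from the earlier snippets (toy random data, so the exact decay rate is illustrative):

```python
# Forward pass, keeping every hidden state h^(0) ... h^(T)
h = np.zeros(D_h)
hs = [h]
for t in range(T):
    h, y_hat = rnn_step(h, xs[t])
    hs.append(h)

# Gradient of J^(T) w.r.t. h^(T): softmax + cross-entropy gives (y_hat - y)
y = np.zeros(V)
y[targets[-1]] = 1.0
delta = W_s.T @ (y_hat - y)          # dJ^(T)/dh^(T)

# Backpropagate through time, printing the gradient norm at each step;
# the norm shrinks roughly exponentially as t - k grows.
for k in range(T, 0, -1):
    print(f"k = {k:2d}   ||dJ(T)/dh(k)|| = {np.linalg.norm(delta):.3e}")
    s = hs[k] * (1.0 - hs[k])        # sigma'(z^(k)), since h^(k) = sigmoid(z^(k))
    delta = W_h.T @ (s * delta)      # chain through dh^(k)/dh^(k-1)
```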