Vanishing Gradient and fancy RNNs
1 Vanishing Gradient(梯度消失)
RNN:
h
(
t
)
=
σ
(
W
h
h
(
t
−
1
)
+
W
x
x
(
t
)
)
y
^
(
t
)
=
s
o
f
t
m
a
x
(
W
s
h
(
t
)
)
\mathbf{h}^{(t)} = \sigma(\mathbf{W}_h \mathbf{h}^{(t-1)}+\mathbf{W}_x \mathbf{x}^{(t)}) \\ \mathbf{\hat{y}}^{(t)}=softmax(\mathbf{W}_s \mathbf{h}^{(t)})
h(t)=σ(Whh(t−1)+Wxx(t))y^(t)=softmax(Wsh(t))
其中,
x
(
t
)
∈
R
d
\mathbf{x}^{(t)}\in \mathbb{R}^d
x(t)∈Rd,
h
(
t
)
∈
R
D
h
\mathbf{h}^{(t)}\in \mathbb{R}^{D_h}
h(t)∈RDh,
W
x
∈
R
D
h
×
d
\mathbf{W}_x \in \mathbb{R}^{D_h\times d}
Wx∈RDh×d,
W
h
∈
R
D
h
×
D
h
\mathbf{W}_h \in \mathbb{R}^{D_h\times D_h}
Wh∈RDh×Dh,
y
^
(
t
)
∈
R
∣
V
∣
\mathbf{\hat{y}}^{(t)} \in \mathbb{R}^{|V|}
y^(t)∈R∣V∣代表对每个单词预测的概率,
W
s
∈
R
∣
V
∣
×
D
h
\mathbf{W}_s \in \mathbb{R}^{|V|\times D_h}
Ws∈R∣V∣×Dh,
∣
V
∣
|V|
∣V∣代表vocabulary大小,
σ
\sigma
σ是sigmoid函数。
timestep t处的loss为:
J
(
t
)
(
θ
)
=
−
∑
j
=
1
∣
V
∣
y
j
(
t
)
log
y
^
j
(
t
)
J^{(t)}(\theta) = -\sum_{j=1}^{|V|} \mathbf{y}^{(t)}_j \log \mathbf{\hat{y}}^{(t)}_j
J(t)(θ)=−j=1∑∣V∣yj(t)logy^j(t)
整个大小为T的语料库(corpus)上的loss:
J
=
1
T
∑
t
=
1
T
J
(
t
)
(
θ
)
=
−
1
T
∑
t
=
1
T
∑
j
=
1
∣
V
∣
y
j
(
t
)
log
y
^
j
(
t
)
J=\frac{1}{T} \sum_{t=1}^T J^{(t)}(\theta) = - \frac{1}{T} \sum_{t=1}^T \sum_{j=1}^{|V|} \mathbf{y}^{(t)}_j \log \mathbf{\hat{y}}^{(t)}_j
J=T1t=1∑TJ(t)(θ)=−T1t=1∑Tj=1∑∣V∣yj(t)logy^j(t)
perplexity为
e
J
e^J
eJ
总的loss对参数的导数:
∂
J
∂
W
=
∑
t
=
1
T
∂
J
(
t
)
∂
W
\frac{\partial J}{\partial \mathbf{W}}= \sum_{t=1}^T \frac{\partial J^{(t)}}{\partial \mathbf{W}}
∂W∂J=t=1∑T∂W∂J(t)
其中,
∂ J ( t ) ∂ W = ∂ J ( t ) ∂ y ^ ( t ) ∂ y ^ ( t ) ∂ h ( t ) ∑ k = 1 t ∂ h ( t ) ∂ h ( k ) ∂ h ( k ) ∂ W \frac{\partial J^{(t)}}{\partial \mathbf{W}} = \frac{\partial J^{(t)}}{\partial \mathbf{\hat{y}}^{(t)}} \frac{\partial \mathbf{\hat{y}}^{(t)}}{\partial \mathbf{h}^{(t)}} \sum_{k=1}^t \frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(k)}} \frac{\partial \mathbf{h}^{(k)}}{\partial \mathbf{W}} ∂W∂J(t)=∂y^(t)∂J(t)∂h(t)∂y^(t)k=1∑t∂h(k)∂h(t)∂W∂h(k)
∂ h ( t ) ∂ h ( k ) = ∏ j = k + 1 t W T d i a g ( σ ′ ( h ( j − 1 ) ) ) \frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(k)}} = \prod_{j=k+1}^{t}\mathbf{W}^Tdiag(\sigma'(\mathbf{h}^{(j-1)})) ∂h(k)∂h(t)=j=k+1∏tWTdiag(σ′(h(j−1)))
∂
h
(
t
)
∂
h
(
t
−
1
)
=
d
i
a
g
(
σ
′
(
W
h
h
t
−
1
+
W
x
x
(
t
)
+
b
1
)
)
W
h
\frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(t-1)}} = diag(\sigma{'}(\mathbf{W}_h \mathbf{h}^{t-1}+\mathbf{W}_x \mathbf{x}^{(t)} + \mathbf{b}_1))\mathbf{W}_h
∂h(t−1)∂h(t)=diag(σ′(Whht−1+Wxx(t)+b1))Wh
第i步的loss
J
(
i
)
(
θ
)
J^{(i)}(\theta)
J(i)(θ)对他之前的第j步的隐层
h
(
j
)
\mathbf{h}^{(j)}
h(j)偏导数为:
∂
J
(
i
)
(
θ
)
∂
h
(
j
)
=
∂
J
(
i
)
(
θ
)
∂
h
(
i
)
∏
j
<
t
≤
i
∂
h
(
t
)
∂
h
(
t
−
1
)
=
∂
J
(
i
)
(
θ
)
∂
h
(
i
)
W
h
(
i
−
j
)
∏
j
<
t
≤
i
d
i
a
g
(
σ
′
(
W
h
h
t
−
1
+
W
x
x
(
t
)
+
b
1
)
)
\begin{aligned} \frac{\partial J^{(i)}(\theta)}{\partial \mathbf{h}^{(j)}} &=\frac{\partial J^{(i)}(\theta)}{\partial \mathbf{h}^{(i)}} \prod_{j<t\le i} \frac{\partial \mathbf{h}^{(t)}}{\partial \mathbf{h}^{(t-1)}}\\ &= \frac{\partial J^{(i)}(\theta)}{\partial \mathbf{h}^{(i)}} \mathbf{W}_h^{(i-j)}\prod_{j<t\le i}diag(\sigma{'}(\mathbf{W}_h \mathbf{h}^{t-1}+\mathbf{W}_x \mathbf{x}^{(t)} + \mathbf{b}_1)) \end{aligned}
∂h(j)∂J(i)(θ)=∂h(i)∂J(i)(θ)j<t≤i∏∂h(t−1)∂h(t)=∂h(i)∂J(i)(θ)Wh(i−j)j<t≤i∏diag(σ′(Whht−1+Wxx(t)+b1))
可以看出有一项
W
h
(
i
−
j
)
\mathbf{W}_h^{(i-j)}
Wh(i−j)是
W
h
\mathbf{W}_h
Wh指数,假如
W
h
\mathbf{W}_h
Wh比较小的化,随着i与j之间距离变长(即i-j变大),偏导数会指数的变小(vanishingly small)。
Proof Sketch
考虑矩阵L2范式(Matrix L2 Norms),有:
∥
∂
J
(
i
)
(
θ
)
∂
h
(
j
)
∥
≤
∥
∂
J
(
i
)
(
θ
)
∂
h
(
i
)
∥
∥
W
h
∥
(
i
−
j
)
∏
j
<
t
≤
i
∥
d
i
a
g
(
σ
′
(
W
h
h
t
−
1
+
W
x
x
(
t
)
+
b
1
)
)
∥
\left \| \frac{\partial J^{(i)}(\theta)}{\partial \mathbf{h}^{(j)}} \right \| \le \left \| \frac{\partial J^{(i)}(\theta)}{\partial \mathbf{h}^{(i)}} \right \| \left \| \mathbf{W}_h \right \|^{(i-j)} \prod_{j<t\le i} \left \| diag\left (\sigma{'}(\mathbf{W}_h \mathbf{h}^{t-1}+\mathbf{W}_x \mathbf{x}^{(t)} + \mathbf{b}_1) \right )\right \|
∥∥∥∥∂h(j)∂J(i)(θ)∥∥∥∥≤∥∥∥∥∂h(i)∂J(i)(θ)∥∥∥∥∥Wh∥(i−j)j<t≤i∏∥∥∥diag(σ′(Whht−1+Wxx(t)+b1))∥∥∥
[Pascanu et al](. http://proceedings.mlr.press/v28/pascanu13.pdf) 证明了如果 W h \mathbf{W}_h Wh最大的特征值小于1,那么梯度就会指数衰减。因为我们使用sigmoid,这里的 d i a g ( σ ′ ( W h h t − 1 + W x x ( t ) + b 1 ) ) diag(\sigma{'}(\mathbf{W}_h \mathbf{h}^{t-1}+\mathbf{W}_x \mathbf{x}^{(t)} + \mathbf{b}_1)) diag(σ′(Whht−1+Wxx(t)+b1))一个上界是1。
为什么梯度消失给我们带来困扰?
梯度可以被当作过去对未来的影响的度量。(Gradient can be viewed as measure of the effect of the past on the future)
由于梯度消失的问题,我们对于较远处无法计算梯度,模型的权重参数只会更新近处的,相当于丢失了长距离的信息。(So model weights are only updated only with respect to near effects, not long-term effects.),因此模型就学习不到长距离依赖(long-distance dependency)。
Exploding Gradient
类比梯度消失,梯度爆炸问题就是梯度增长的过大,导致Inf或者NaN问题。
针对梯度爆炸问题的一个解决方法是gradient clipping,方法是如果梯度的范式大于一定的阈值,就scale it down,相同方向但是更小的步长。
if
∥
g
^
∥
≥
t
h
r
e
s
h
o
l
d
then
g
^
←
t
h
r
e
s
h
o
l
d
∥
g
^
∥
g
^
end if
\begin{aligned} &\text{if}\quad\| \hat{g} \| \ge {threshold}\quad \text{then} \\ &\qquad \hat{g} \leftarrow\frac{threshold}{\| \hat{g} \|} \hat{g} \\ &\text{end if} \end{aligned}
if∥g^∥≥thresholdtheng^←∥g^∥thresholdg^end if
2 Long-Short Time Memory(LSTM)
RNN主要的问题就是无法传递长距离的信息。一个比较直观的想法就是将信息保存下来,这引出了一个RNN变体LSTM,由Hochreiter 和 Schmidhuber在1997年提出的,作为解决梯度消失问题的一个解决方案。
LSTM在时刻t加如了除隐层
h
(
t
)
h^{(t)}
h(t)外的状态cell state:
c
(
t
)
c^{(t)}
c(t),用于保存长距离信息,这里
h
(
t
)
h^{(t)}
h(t)和
c
(
t
)
c^{(t)}
c(t)大小都是长度为n的向量,此外定义了三种对cell state信息的操作:erase,write,read。这三种操作由相应的门(gate)控制,这些门与cell state大小一致,值在01之间,大小相当于对信息的保留程度,此外门是动态变化的,根据当前的context计算门的值。三个门formal的定义如下:
f
(
t
)
=
σ
(
W
f
h
(
t
−
1
)
+
U
f
x
(
t
)
+
b
f
)
Forget gate
i
(
t
)
=
σ
(
W
i
h
(
t
−
1
)
+
U
i
x
(
t
)
+
b
i
)
Input gate
o
(
t
)
=
σ
(
W
o
h
(
t
−
1
)
+
U
o
x
(
t
)
+
b
o
)
Output gate
\mathbf{f}^{(t)}=\sigma(\mathbf{W}_f\mathbf{h}^{(t-1)}+\mathbf{U}_f\mathbf{x}^{(t)}+\mathbf{b}_f) \qquad \text{Forget gate}\\ \mathbf{i}^{(t)}=\sigma(\mathbf{W}_i\mathbf{h}^{(t-1)}+\mathbf{U}_i\mathbf{x}^{(t)}+\mathbf{b}_i) \qquad \text{Input gate}\\ \mathbf{o}^{(t)}=\sigma(\mathbf{W}_o\mathbf{h}^{(t-1)}+\mathbf{U}_o\mathbf{x}^{(t)}+\mathbf{b}_o) \qquad \text{Output gate}\\
f(t)=σ(Wfh(t−1)+Ufx(t)+bf)Forget gatei(t)=σ(Wih(t−1)+Uix(t)+bi)Input gateo(t)=σ(Woh(t−1)+Uox(t)+bo)Output gate
Forget gate控制从上一个cell state中保存哪些信息舍弃哪些信息,Input gate控制把"new cell content"中哪些信息写入cell中,Output gate控制把cell中哪些信息输出到隐层中。下面给出利用上面的门计算cell state和hidden state:
c
~
(
t
)
=
t
a
n
h
(
W
c
h
(
t
−
1
)
+
U
c
x
(
t
)
+
b
c
)
New cell content
c
(
t
)
=
f
(
t
)
∘
c
(
t
−
1
)
+
i
(
t
)
∘
c
~
(
t
)
Cell state
h
(
t
)
=
o
(
t
)
∘
c
(
t
)
Hidden state
\begin{aligned} &\tilde{\mathbf{c}}^{(t)}=tanh(\mathbf{W}_c\mathbf{h}^{(t-1)}+\mathbf{U}_c\mathbf{x}^{(t)}+\mathbf{b}_c) \qquad &\text{New cell content}\\ &\mathbf{c}^{(t)}= \mathbf{f}^{(t)}\circ \mathbf{c}^{(t-1)}+ \mathbf{i}^{(t)} \circ \tilde{\mathbf{c}}^{(t)}\qquad &\text{Cell state}\\ &\mathbf{h}^{(t)}= \mathbf{o}^{(t)}\circ \mathbf{c}^{(t)}\qquad &\text{Hidden state}\\ \end{aligned}
c~(t)=tanh(Wch(t−1)+Ucx(t)+bc)c(t)=f(t)∘c(t−1)+i(t)∘c~(t)h(t)=o(t)∘c(t)New cell contentCell stateHidden state
其中new cell content:
c
~
(
t
)
\tilde{\mathbf{c}}^{(t)}
c~(t)代表要写入cell state的新内容,cell state使用forget gate忘记一部分上一时刻的cell state信息,使用input gate输入一部分新的内容,hidden state使用output gate控制哪部分cell state作为输出。其中上述所有向量都是相同大小,门的控制使用element-wise product:
∘
\circ
∘。某一时刻的LSTM单元示意图如下:
所以为什么LSTM能够解决梯度消失问题?这来自于LSTM的结构更容易保存很长步的信息,对比RNN循环使用的参数矩阵并不能存储太多信息。理解这些结构可以从highway,ResNet,denseNet等网络结构中得到一些启发-skip-connections,其实就相当于引入信息传递的捷径,以获取传递过程中较少的信息损失。LSTM并不能保证不出现梯度消失/爆炸问题,但是相比较RNN它确实能够学习到长距离依赖(long-distance depencies)。
3 Gated Recurrent Units(GRU)
2014年有Cho et al.提出的GRU,作为LSTM的一个简单的替换,在时刻t,只计算hidden state,去掉了cell state,门控如下:
u
(
t
)
=
σ
(
W
u
h
(
t
−
1
)
+
U
u
x
(
t
)
+
b
u
)
Update gate
r
(
t
)
=
σ
(
W
r
h
(
t
−
1
)
+
U
r
x
(
t
)
+
b
r
)
Reset gate
\mathbf{u}^{(t)}=\sigma(\mathbf{W}_u\mathbf{h}^{(t-1)}+\mathbf{U}_u\mathbf{x}^{(t)}+\mathbf{b}_u) \qquad \text{Update gate}\\ \mathbf{r}^{(t)}=\sigma(\mathbf{W}_r\mathbf{h}^{(t-1)}+\mathbf{U}_r\mathbf{x}^{(t)}+\mathbf{b}_r) \qquad \text{Reset gate}\\
u(t)=σ(Wuh(t−1)+Uux(t)+bu)Update gater(t)=σ(Wrh(t−1)+Urx(t)+br)Reset gate
update gate控制hidden state中哪些update哪些preserver,reset gate控制使用前一个hidden state的哪些内容来计算new hidden state content。
隐层更新如下:
h
~
(
t
)
=
t
a
n
h
(
W
h
(
r
(
t
)
∘
h
(
t
−
1
)
)
+
U
h
x
(
t
)
+
b
h
)
New hidden state content
h
(
t
)
=
(
1
−
u
(
t
)
)
∘
h
(
t
−
1
)
+
u
(
t
)
∘
h
~
(
t
)
Hidden state
\begin{aligned} &\tilde{\mathbf{h}}^{(t)}=tanh(\mathbf{W}_h (\mathbf{r}^{(t)}\circ \mathbf{h}^{(t-1)}) +\mathbf{U}_h\mathbf{x}^{(t)}+\mathbf{b}_h) \qquad &\text{New hidden state content}\\ &\mathbf{h}^{(t)}= (1-\mathbf{u}^{(t)})\circ \mathbf{h}^{(t-1)} + \mathbf{u}^{(t)}\circ \tilde{\mathbf{h}}^{(t)} \qquad &\text{Hidden state}\\ \end{aligned}
h~(t)=tanh(Wh(r(t)∘h(t−1))+Uhx(t)+bh)h(t)=(1−u(t))∘h(t−1)+u(t)∘h~(t)New hidden state contentHidden state
new hidden state content:
h
~
(
t
)
\tilde{\mathbf{h}}^{(t)}
h~(t)的计算通过reset gate:
r
(
t
)
\mathbf{r}^{(t)}
r(t)挑选前一hidden state的信息,hidden state的计算通过update gate同步控制从前一hidden state中获取多少信息,从new hidden state content中获取多少信息。
LSTM vs GRU
LSTM和GRU是最广泛使用的两个门控RNN变体,两个最大的不同是GRU计算比LSTM快一些,并且需要的参数更少,关于性能方面,没有证据表明LSTM好还是GRU好。
Rule of thumb: 使用LSTM,但是如果追求效率再换成GRU。
4 Bidirectional RNNs
很多任务需要获取一个单词周围的环境,而不是仅仅单向的,例如情感分析(sentiment analysis)任务,如果只依赖于前半部分句子的信息我们很难判断整个句子的情感,而传统RNN都是单向的,对于每个时刻隐层的计算,输入信息只有 h ( t − 1 ) \mathbf{h}^{(t-1)} h(t−1)和 x ( t ) \mathbf{x}^{(t)} x(t),意味着当前的计算获得不到后面的信息,双向RNNs提出就是为了解决这一问题,在时刻t隐层的计算会同时获取前一时刻的隐层和后一时刻的隐层信息。
时刻t双向RNNs计算如下:
Forward RNNs
h
→
(
t
)
=
R
N
N
F
W
(
h
→
(
t
−
1
)
,
x
(
t
)
)
Backward RNNs
h
←
(
t
)
=
R
N
N
B
W
(
h
←
(
t
−
1
)
,
x
(
t
)
)
Concatenated hidden state
h
(
t
)
=
[
h
→
(
t
)
;
h
←
(
t
)
]
\text{Forward RNNs} \qquad \overrightarrow{\mathbf{h}}^{(t)}={RNN}_{FW}(\overrightarrow{\mathbf{h}}^{(t-1)}, \mathbf{x}^{(t)}) \\ \text{Backward RNNs} \qquad \overleftarrow{\mathbf{h}}^{(t)}={RNN}_{BW}(\overleftarrow{\mathbf{h}}^{(t-1)}, \mathbf{x}^{(t)}) \\ \text{Concatenated hidden state} \qquad \mathbf{h}^{(t)}=[\overrightarrow{\mathbf{h}}^{(t)};\overleftarrow{\mathbf{h}}^{(t)}] \\
Forward RNNsh(t)=RNNFW(h(t−1),x(t))Backward RNNsh(t)=RNNBW(h(t−1),x(t))Concatenated hidden stateh(t)=[h(t);h(t)]
使用双向RNN的前提是我们能够获取整个句子的信息,例如在language model中就不适用,因为我们只能获取一个单词左边的内容。
5 Multi-layer RNNs
RNNs相当于在一个维度上很deep(时间维度),因为是在多个时刻计算,我们也可以通过叠加多个RNN层,使其在其他维度也deep起来,deep往往意味着表达能力强,所以通过这个方式以期望模型获取更复杂的表达能力,使其lowe的层计算lower-level的特征,higer的层计算higher-level的特征。multilayer RNNs有时也称为stacked RNNs。
如图,每个lower层的hidden state作为它上一层的输入。在实践中,multilayer RNN一般是性能优于普通RNN的,但是multilay RNN普遍没有像深层卷积网络和深层前馈网络那么deep,有些卷积网络都有上百层。2017年Britz et al发现在NMT任务中,encoder RNN最好的是在2到4层,decoder RNN最好的是4层。基于Transformer的网络如BERT可以达到24层。如果想让RNN更deep,可以采取skip-connections/dense-connections的技巧。