RNN和LSTM的理论推导
注:LSTM-BP待日后填坑
第一章 RNN的提出背景和作用
1.1RNN的提出背景
在传统的神经网络中,各个输入在算法内部是相对独立的,无法从先前的信息中进行推理,难以处理序列类型的数据。
1.2RNN的作用
在BP算法提出之后,学者和研究员们又提出了具有短期记忆能力的循环神经网络。RNN能记住之前的输入值,即使后面有相同的输入,输出值也会不同。目前已被广泛应用在语音识别,语言模型以及自然语言生成等任务上。
第二章 RNN的理论推导
2.1 RNN的网络结构
RNN的基本结构如下图所示,左侧是RNN网络,右侧是RNN网络按时序展开的形式。
RNN在传统神经网络输入层——隐藏层——输出层的基础上增加了一个类似于延时器的单元,能够记录上一次的输出值,并带到下一次的输入当中,以此记录最近几次的活性值V,RNN也由此具备了短期记忆能力。
将RNN按时序展开就可以得到RNN的结构
- 代表t时刻的输入
- 代表t时刻隐藏层的状态
- 代表t时刻的输出
- 代表输入层到隐藏层的权重
- V代表隐藏状态到下一隐藏状态的权重
- W代表隐藏层到输出层的权重
U,V,W是该模型中的线性关系参数,它在整个网络中是共享的,体现出RNN模型“循环反馈”的思想。
2.2 RNN的前向传播
RNN以时刻t为参数进行循环。虽然每一时刻的输入不同,但其对应的结构不变,所以每一次循环就相当于在进行递归,递推公式如下:
h t = σ ( x t U + h t − 1 V + b h ) h_{t}=\sigma(x_{t}U+h_{t-1}V+b_{h}) ht=σ(xtU+ht−1V+bh)
其中
σ
\sigma
σ为RNN的激活函数,一般为tanh。是该线性关系中的的偏置。
最终输出的表达式为:
y
t
^
=
φ
(
W
h
t
+
b
y
)
\hat{y_{t}}=\varphi(Wh_{t}+b_{y})
yt^=φ(Wht+by)
激活函数
φ
\varphi
φ一般是softmax。是该线性关系中的的偏置。
2.3 RNN的反向传播
类似于传统神经网络,RNN神经网络的反向传播算法思路也是通过梯度下降法一轮轮的选代。由于是基于时间的反向传播,所以将RNN神经网络的反向传播命名为BPTT(back-propagation through time)。我们利用反向传播算法将输出层的误差加和,然后对各个权重的参数矩阵求梯度,再利用梯度下降法更新各个权重。
对于每一时刻t的RNN网络,网络的输出在每个时刻都会产生损失。那么总的损失为。我们的目标就是要求取对的偏导。
对于预测结果的任意损失函数,求取是最简单的,我们可以直接求取每个时刻的,由于它不存在和之前的依赖状态,可以直接求导取得,然后简单求和即可,算式如下:
∂
l
o
s
s
∂
W
=
∑
t
=
1
T
∂
L
t
∂
y
t
^
⋅
∂
y
t
^
∂
W
\frac{\partial loss}{\partial W}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial W}}
∂W∂loss=t=1∑T∂yt^∂Lt⋅∂W∂yt^
对于V的计算不能直接求导,因此需要用链式求导法则。
以对V求梯度为例
∂
l
o
s
s
∂
V
=
∑
t
=
1
T
∂
L
t
∂
V
=
∑
t
=
1
T
∂
L
t
∂
y
t
^
⋅
∂
y
t
^
∂
h
t
⋅
∂
h
t
∂
V
(
1
)
\frac{\partial loss}{\partial V}=\sum_{t=1}^{T}\mathord{\frac{\partial L_{t}}{\partial V}}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial {h_t}}\cdot\frac{\partial h_{t}}{\partial V}} \quad(1)
∂V∂loss=t=1∑T∂V∂Lt=t=1∑T∂yt^∂Lt⋅∂ht∂yt^⋅∂V∂ht(1)
由公式
h
t
=
σ
(
x
t
U
+
h
t
−
1
V
+
b
h
)
h_{t}=\sigma(x_{t}U+h_{t-1}V+b_{h})
ht=σ(xtU+ht−1V+bh)对
∂
h
t
∂
V
\frac{\partial h_{t}}{\partial V}
∂V∂ht单独进行展开,可得
∂
h
t
∂
V
=
∑
k
=
1
t
∂
h
t
∂
h
k
⋅
∂
h
k
∂
V
=
∑
k
=
1
t
(
∏
j
=
k
+
1
t
h
j
h
j
−
1
)
⋅
∂
h
t
∂
V
(
2
)
\frac{\partial h_{t}}{\partial V}=\sum_{k=1}^{t}\mathop{\frac{\partial{h_{t}}}{\partial {h_k}}\cdot\frac{\partial h_{k}}{\partial V}} =\sum_{k=1}^{t}\mathbin{(\prod_{j=k+1}^{t}\mathbin{\frac{h_{j}}{h_{j-1}})}\cdot\frac{\partial h_{t}}{\partial V}} \quad(2)
∂V∂ht=k=1∑t∂hk∂ht⋅∂V∂hk=k=1∑t(j=k+1∏thj−1hj)⋅∂V∂ht(2)
将(2)式代入(1)式,得
∂
l
o
s
s
∂
V
=
∑
t
=
1
T
∂
L
t
∂
y
t
^
⋅
∂
y
t
^
∂
h
t
∑
k
=
1
t
(
∏
j
=
k
+
1
t
h
j
h
j
−
1
)
⋅
∂
h
t
∂
V
\frac{\partial loss}{\partial V}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial {h_t}}\sum_{k=1}^{t}\mathbin{(\prod_{j=k+1}^{t}\mathbin{\frac{h_{j}}{h_{j-1}})}\cdot\frac{\partial h_{t}}{\partial V}}}
∂V∂loss=t=1∑T∂yt^∂Lt⋅∂ht∂yt^k=1∑t(j=k+1∏thj−1hj)⋅∂V∂ht
同理可得对U的梯度
∂
l
o
s
s
∂
U
=
∑
t
=
1
T
∂
L
t
∂
y
t
^
⋅
∂
y
t
^
∂
h
t
∑
k
=
1
t
(
∏
j
=
k
+
1
t
h
j
h
j
−
1
)
⋅
∂
h
t
∂
U
\frac{\partial loss}{\partial U}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial {h_t}}\sum_{k=1}^{t}\mathbin{(\prod_{j=k+1}^{t}\mathbin{\frac{h_{j}}{h_{j-1}})}\cdot\frac{\partial h_{t}}{\partial U}}}
∂U∂loss=t=1∑T∂yt^∂Lt⋅∂ht∂yt^k=1∑t(j=k+1∏thj−1hj)⋅∂U∂ht
第三章 传统RNN模型的缺陷和LSTM的提出
3.1 RNN的梯度消失和梯度爆炸
在计算t时刻损失产生的梯度时,必须回溯之前所有时刻的信息。
但是我们会发现一个问题,即最后要对所有时刻的梯度进行累加。而每个时刻都是在后一个时刻的基础上进行累乘的结果。
若累乘的次数过于庞大,每次都连续乘一个小于1的数字,就会导致最终结果趋近于0,即梯度消失。
反之,当每次都连续乘一个大于1的数字,就会导致最终的结果趋近于无穷,即梯度爆炸。
3.2 RNN缺陷的解决方案
为了克服梯度爆炸和梯度消失问题,最直观的想法就是使 h j h j − 1 = 1 \frac{h_{j}}{h_{j-1}}=1 hj−1hj=1。梯度裁剪,通过把沿梯度下降方向的步长限制在一个范围内,解决了梯度爆炸的问题,但梯度消失的问题仍难以解决。1997年,Hochreiter和Schmidhuber首先提出了LSTM的网络结构,通过CEC(constant error carrousel)单元,控制其结果为0或接近于1,解决了传统RNN的这一缺陷。
第四章 LSTM的理论推导
4.1 LSTM的基本结构
LSTM的基本结构如下图所示:
LSTM由遗忘门、输入门、输出门和细胞状态组成。
4.2 LSTM的遗忘门
遗忘门能够以一定的概率控制是否遗忘上一层的细胞状态。其数学表达式为:
f
t
=
σ
(
W
f
⋅
[
x
t
,
h
t
−
1
]
+
b
f
)
f_{t}=\sigma(W_{f}\cdot [x_{t},h_{t-1}]+b_{f})
ft=σ(Wf⋅[xt,ht−1]+bf)
其中激活函数
σ
\sigma
σ为sigmoid,为该线性关系中的偏置。上一时刻的输出和本时刻的输入作为该函数的输入,通过激活函数,得到遗忘门。
sigmoid函数的输出在[0,1]之间,表示细胞状态中被保留信息的概率值。输出越接近0,代表被遗忘的信息越多;输出越接近于1,代表被遗忘的信息越少。
4.3 LSTM的输入门
输入门由两部分组成,能处理当前时刻的输入,进行数据加强,并对细胞状态进行更新。
第一部分会进行两个操作。
1.为忽略因子,能够决定删除哪些部分,其数学表达式为:
i
t
=
σ
(
W
i
⋅
[
x
t
,
h
t
−
1
]
+
b
i
)
i_{t}=\sigma(W_{i}\cdot [x_{t},h_{t-1}]+b_{i})
it=σ(Wi⋅[xt,ht−1]+bi)
2.使用激活函数tanh创建了一个新的输入,其数学表达式为:
a
t
=
t
a
n
h
(
W
f
⋅
[
x
t
,
h
t
−
1
]
+
b
a
)
a_{t}=tanh(W_{f}\cdot [x_{t},h_{t-1}]+b_{a})
at=tanh(Wf⋅[xt,ht−1]+ba)
第二部分,对细胞状态进行更新。将和相乘,遗忘掉不必要的信息,再将新的输入值与,完成从到的更新,其数学表达式为:
C
t
=
f
t
⊙
C
t
−
1
+
i
t
⊙
a
t
C_{t}=f_{t} \odot C_{t-1}+i_{t} \odot a_{t}
Ct=ft⊙Ct−1+it⊙at
4.4 LSTM的输出门
输出门会基于刚更新的细胞状态进行输出,其表达式为
o
t
=
σ
(
W
o
⋅
[
x
t
,
h
t
−
1
]
+
b
o
)
o_{t}=\sigma(W_{o}\cdot [x_{t},h_{t-1}]+b_{o})
ot=σ(Wo⋅[xt,ht−1]+bo)
h
t
=
o
t
⊙
t
a
n
h
(
C
t
)
h_{t}=o_{t} \odot tanh(C_{t})
ht=ot⊙tanh(Ct)
输出门由上一时刻的输出和本时刻的输入,以及细胞状态两部分组成。首先使用sigmoid函数求得,来确定输出细胞状态的哪一部分。再通过tanh函数处理细胞状态,并与,输出这一时刻的结果
最后再进行预测输出
y
t
^
=
σ
(
V
⋅
h
t
]
+
b
y
)
\hat{y_{t}}=\sigma(V\cdot h_{t}]+b_{y})
yt^=σ(V⋅ht]+by)
LSTM的反向传播
待填坑