一、循环神经网络
传统的神经网络并不能做到保持信息的持久性,RNN(Recurrent Neural Retwork) 解决了这个问题。RNN 是包含循环的网络,允许信息的持久化。
在上面的示例图中,神经网络的模块,,正在读取某个输入 ,并输出一个值 。循环可以使得信息可以从当前步传递到下一步。
RNN 可以被看做是同一神经网络的多次复制,每个神经网络模块会把消息传递给下一个。所以,如果我们将这个循环展开:
RNN 的关键点之一就是他们可以用来连接先前的信息到当前的任务上,例如使用过去的视频段来推测对当前段的理解。
但是LSTM存在长期依赖(Long-Term Dependencies)问题
,也就是当序列特别长的时候,RNN 会丢失之前的信息。于是提出了LSTM来改进这些问题。
同时RNN还存在梯度消失和梯度爆炸
的问题。可以通过改变激活函数来解决,同时LSTM也可以解决梯度消失和梯度爆炸的问题。
1.1 标准RNN的前向输出流程
下面这是一个RNN详细的结构图,其中各个符号的含义:x是输入,h是隐层单元,o为输出
,y为训练集的标签
,L为损失函数。这些元素右上角带的t代表t时刻的状态,其中需要注意的是,因策单元h在t时刻的表现不仅由此刻的输入决定,还受t时刻之前时刻的影响。V、W、U是权值,同一类型的权连接权值相同。
前向传播算法其实非常简单,t时刻隐状态
h
(
t
)
h^{(t)}
h(t)为:
h
(
t
)
=
ϕ
(
U
x
(
t
)
+
W
h
(
t
−
1
)
+
b
)
h^{(t)}=\phi(Ux^{(t)}+Wh^{(t-1)}+b)
h(t)=ϕ(Ux(t)+Wh(t−1)+b)
其中
ϕ
(
)
\phi()
ϕ()为激活函数,一般来说会选择tanh函数
(注意这个tanh函数,它是引起梯度消失和梯度爆炸的原因,下面会细讲),b为偏置。
t时刻的输出
o
(
t
)
o^{(t)}
o(t)就更为简单(c为偏置):
o
(
t
)
=
V
h
(
t
)
+
c
o^{(t)}=Vh^{(t)}+c
o(t)=Vh(t)+c
t时刻模型的预测输出
y
(
t
)
y^{(t)}
y(t)为:
y
(
t
)
=
σ
(
o
(
t
)
)
y^{(t)}=\sigma(o^{(t)})
y(t)=σ(o(t))
其中
σ
(
)
\sigma()
σ()为激活函数,通常RNN用于分类,故这里一般用softmax函数。
1.2 RNN的训练方法—BPTT
BPTT(back-propagation through time)算法是常用的训练RNN的方法,其实本质还是BP算法,只不过RNN处理时间序列数据,所以要基于时间反向传播
,故叫随时间反向传播。BPTT的中心思想和BP算法相同,沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛。综上所述,BPTT算法本质还是BP算法,BP算法本质还是梯度下降法,那么求各个参数的梯度便成了此算法的核心。
再次拿出这个结构图观察,需要寻优的参数有三个,分别是U、V、W
。与BP算法不同的是,其中W和U两个参数的寻优过程需要追溯之前的历史数据
,参数V相对简单只需关注目前时刻t。
(1)那么我们就来先求解参数V的偏导数:
∂
L
(
t
)
∂
V
=
∂
L
(
t
)
∂
o
(
t
)
⋅
∂
o
(
t
)
∂
V
\frac{\partial L^{(t)}}{\partial V}=\frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V}
∂V∂L(t)=∂o(t)∂L(t)⋅∂V∂o(t)
这个式子看起来简单但是求解起来很容易出错,因为其中嵌套着激活函数函数,是复合函数的求道过程。RNN的损失也是会随着时间累加的,所以不能只求t时刻的偏导,要把所有时刻的偏导都求出来再累加
:
L
=
∑
t
=
1
n
L
(
t
)
L=\sum_{t=1}^n L^{(t)}
L=∑t=1nL(t)
∂
L
∂
V
=
∑
t
=
1
n
∂
L
(
t
)
∂
o
(
t
)
⋅
∂
o
(
t
)
∂
V
\frac{\partial L}{\partial V}=\sum_{t=1}^n \frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V}
∂V∂L=∑t=1n∂o(t)∂L(t)⋅∂V∂o(t)
(2)W和U的偏导的求解由于需要涉及到历史数据,其偏导求起来相对复杂,我们先假设只有三个时刻
那么在第三个时刻
L对W的偏导数为:
∂
L
(
3
)
∂
W
=
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
W
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
W
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
h
(
1
)
∂
h
(
1
)
∂
W
\frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W}
∂W∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂W∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂W∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂W∂h(1)
相应的,L在第三个时刻
对U的偏导数为:
∂
L
(
3
)
∂
U
=
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
U
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
U
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
h
(
1
)
∂
h
(
1
)
∂
U
\frac{\partial L^{(3)}}{\partial U}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial U}
∂U∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂U∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂U∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂U∂h(1)
可以观察到,在某个时刻的对W或是U的偏导数,需要追溯这个时刻之前所有时刻的信息
,这还仅仅是一个时刻的偏导数,上面说过损失也是会累加的,那么整个损失函数对W和U的偏导数将会非常繁琐。虽然如此但好在规律还是有迹可循,我们根据上面两个式子可以写出L在t时刻对W和U偏导数的通式
:
∂
L
(
t
)
∂
W
=
∑
k
=
0
t
∂
L
(
t
)
∂
o
(
t
)
∂
o
(
t
)
∂
h
(
t
)
(
∏
j
=
k
+
1
t
∂
h
(
j
)
∂
h
(
j
−
1
)
)
∂
h
(
k
)
∂
W
\frac{\partial L^{(t)}}{\partial W}=\sum_{k=0}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial W}
∂W∂L(t)=k=0∑t∂o(t)∂L(t)∂h(t)∂o(t)(j=k+1∏t∂h(j−1)∂h(j))∂W∂h(k)
∂
L
(
t
)
∂
U
=
∑
k
=
0
t
∂
L
(
t
)
∂
o
(
t
)
∂
o
(
t
)
∂
h
(
t
)
(
∏
j
=
k
+
1
t
∂
h
(
j
)
∂
h
(
j
−
1
)
)
∂
h
(
k
)
∂
U
\frac{\partial L^{(t)}}{\partial U}=\sum_{k=0}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial U}
∂U∂L(t)=k=0∑t∂o(t)∂L(t)∂h(t)∂o(t)(j=k+1∏t∂h(j−1)∂h(j))∂U∂h(k)
整体的偏导公式就是将其按时刻再一一加起来。注意这个累乘里面是t时刻的h对t-1时刻的h求导。
1.3 梯度消失和梯度爆炸
前面说过激活函数是嵌套在里面的,如果我们把激活函数放进去,拿出中间累乘的那部分:
∏
j
=
k
+
1
t
∂
h
j
∂
h
j
−
1
=
∏
j
=
k
+
1
t
t
a
n
h
′
⋅
W
s
\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot W_{s}
∏j=k+1t∂hj−1∂hj=∏j=k+1ttanh′⋅Ws
或是
∏
j
=
k
+
1
t
∂
h
j
∂
h
j
−
1
=
∏
j
=
k
+
1
t
s
i
g
m
o
i
d
′
⋅
W
s
\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{sigmoid^{'}}\cdot W_{s}
∏j=k+1t∂hj−1∂hj=∏j=k+1tsigmoid′⋅Ws
于是这个tanh的导数(或者sigmod的导数)就以累乘
的形式参与到梯度的计算中去。但是我们来看看tanh的导数和sigmod的导数的特征:
-
tanh的函数图像和导数图像:
-
sigmoid的函数图像和导数图像:
它们二者是何其的相似,都把输出压缩在了一个范围之内。他们的导数图像也非常相近,我们可以从中观察到,sigmoid函数的导数范围是(0,0.25],tanh函数的导数范围是(0,1],他们的导数最大都不大于1。这就是会带来几个问题:
(1)如果
W
s
W_{s}
Ws 也是一个大于0小于1的值,使得
t
a
n
h
′
∗
W
s
<
1
tanh' * W_s < 1
tanh′∗Ws<1,则当t很大时,梯度累乘的值就会趋近于0,和 (0.9*0.8)^50趋近与0是一个道理。
(2)同理当
W
s
W_{s}
Ws 很大时,具体指(比如
t
a
n
h
′
=
0.1
tanh' = 0.1
tanh′=0.1,而
W
s
=
99
W_s=99
Ws=99,则相乘为9.9),使得
t
a
n
h
′
∗
W
s
>
1
tanh' * W_s > 1
tanh′∗Ws>1,则当t很大时,梯度累乘的值就会趋近于无穷。
这就是RNN中梯度消失和爆炸
的原因。其实RNN的时间序列与深层神经网络很像,在较为深层的神经网络中使用sigmoid函数做激活函数也会导致反向传播时梯度消失
,梯度消失就意味消失那一层的参数再也不更新,那么那一层隐层就变成了单纯的映射层,毫无意义了,所以在深层神经网络中,有时候多加神经元数量可能会比多家深度好。但是tanh函数相对于sigmoid函数来说梯度较大,收敛速度更快且引起梯度消失更慢
。
sigmoid函数还有一个缺点,Sigmoid函数输出不是零中心对称
。sigmoid的输出均大于0,这就使得输出不是0均值,称为偏移现象
,这将导致后一层的神经元将上一层输出的非0均值的信号作为输入。而关于原点对称的输入和中心对称的输出,网络会收敛地更好
。
RNN的特点本来就是能“追根溯源“利用历史数据,现在告诉我可利用的历史数据竟然是有限的,这就令人非常难受,解决“梯度消失“是非常必要的。解决“梯度消失“的方法主要有:
- 选取更好的激活函数
- 改变传播结构
关于第一点,一般选用ReLU函数作为激活函数,ReLU函数的图像为:
左侧恒为1的导数避免了“梯度消失“的发生。但是容易导致“梯度爆炸“,设定合适的阈值
可以解决这个问题。
但是如果左侧横为0的导数有可能导致把神经元学死,出现这个原因可能是因为学习率太大,导致w更新巨大,使得输入的所有训练样本数据
在经过这个神经元的时候,所有输出值都小于0
,从而经过激活函数Relu计算之后的输出为0,从此不梯度(所有梯度之和
)再更新。所以relu为激活函数,学习率不能太大,设置合适的步长(学习率)也可以有效避免这个问题的发生。
二、长短期记忆神经网络
LSTM(Long short-term Memory):a very special kind of Recurrent Neural Retwork.长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现
。
2.1 LSTM 内部结构
上面介绍的RNN可以用下图表示,内部只有一个tanh激活函数:
LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于单一神经网络层,整体上除了h在随时间流动,细胞状态c也在随时间流动,细胞状态c就代表着长期记忆
:
现在,我们先来熟悉一下图中使用的各种元素的图标:
- 黄色的矩形是学习得到的神经网络层
- 粉色的圆形表示一些运算操作,诸如加法乘法
- 黑色的单箭头表示向量的传输
- 两个箭头合成一个表示向量的连接
- 一个箭头分开表示向量的复制
LSTM 的关键就是细胞状态,水平线在图上方贯穿运行。细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易:
LSTM 有通过精心设计的称作为“门”的结构来去除或者增加信息到细胞状态的能力。门是一种让信息选择式通过的方法。他们包含一个 sigmoid 神经网络层和一个 pointwise 乘法操作:
LSTM 拥有三个门,来保护和控制细胞状态。
2.2 分步理解LSTM
LSTM内部主要有三个阶段:
(1) 忘记阶段
。
这个阶段主要是对上一个节点传进来的
C
t
−
1
C_{t-1}
Ct−1进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”。具体来说是通过计算得到的 ft(f表示forget,也就是下面的ft)来作为忘记门控
,来控制上一个状态的
C
t
−
1
Ct-1
Ct−1 哪些需要留哪些需要忘。输出的ft是一个在 0 到 1 之间的数值
,描述每个部分有多少量可以通过。0 代表“不许任何量通过”,1 就指“允许任意量通过”,小数就是以前百分之多少的内容记住。然后这个ft和
C
t
−
1
C_{t-1}
Ct−1进行一个pointwise 乘法操作,从而达到遗忘的效果。
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
f_t=\sigma(W_f \cdot [h_{t-1},x_t]+b_f)
ft=σ(Wf⋅[ht−1,xt]+bf)
(2) 选择记忆阶段
。
将当前这个t阶段的输入
x
t
x_t
xt有选择性地进行“记忆”到细胞
C
t
C_t
Ct中。(上一步是对前一个输入
C
t
−
1
C_{t-1}
Ct−1进行选择记忆)。哪些重要则着重记录下来,哪些不重要,则少记一些,这里的it充当了一个记忆门控
的作用。要记住的内容暂存为
C
~
t
\tilde C_t
C~t:
i
t
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
+
b
i
)
i_t=\sigma(W_i \cdot [h_{t-1},x_t]+b_i)
it=σ(Wi⋅[ht−1,xt]+bi)
C
~
t
=
t
a
n
h
(
W
c
⋅
[
h
t
−
1
,
x
t
]
+
b
C
)
\tilde C_t=tanh(W_c \cdot [h_{t-1},x_t]+b_C)
C~t=tanh(Wc⋅[ht−1,xt]+bC)
具体记忆的方法如下:先计算
i
t
i_t
it和
C
~
t
\tilde C_t
C~t,再将两者相乘。
经过(1)(2)两个步骤之后,我们就可以更新细胞状态
C
t
C_t
Ct了。我们有了要从
C
t
−
1
C_{t-1}
Ct−1遗忘的和要从
x
t
x_t
xt记住的内容,显而易见,把两个内容相加就是更新之后的细胞状态
C
t
C_t
Ct:
C
t
=
f
t
∗
C
t
−
1
+
i
t
∗
C
~
t
C_t=f_t*C_{t-1}+i_t * \tilde C_t
Ct=ft∗Ct−1+it∗C~t
其实这里的
f
t
f_t
ft和
i
t
i_t
it可以看做两个权重,一个是遗忘(
C
t
−
1
C_{t-1}
Ct−1)权重,一个是记忆(
x
t
x_{t}
xt)权重。
(3)输出阶段
最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个(经过tanh)过滤后的版本。
- 首先,我们运行一个 sigmoid 层来 x t x_t xt哪些内容将输出出去;
- 接着,我们把(更新后的)细胞状态 C t C_t Ct通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘;
- 最终我们仅仅会输出我们确定输出的那部分。
o
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
O
)
o_t=\sigma(W_o \cdot [h_{t-1},x_t]+b_O)
ot=σ(Wo⋅[ht−1,xt]+bO)
h
t
=
o
t
∗
t
a
n
h
(
C
t
)
h_t=o_t * tanh(C_t)
ht=ot∗tanh(Ct)
上面三步对应了三个门控,这三个门虽然功能上不同,但在执行任务的操作上是相同的。他们都是使用sigmoid函数作为选择工具,tanh函数作为变换工具,这两个函数结合起来实现三个门的功能。
三个步骤里的权重 W f , W i , W o W_f,W_i,W_o Wf,Wi,Wo都不一样且都是要学习的。
再看一下输出的
C
t
C_t
Ct和
h
t
h_t
ht:
C
t
=
f
t
∗
C
t
−
1
+
i
t
∗
C
~
t
C_t=f_t*C_{t-1}+i_t*\tilde C_t
Ct=ft∗Ct−1+it∗C~t
h
t
=
o
t
∗
t
a
n
h
(
C
t
)
h_t=o_t*tanh(C_t)
ht=ot∗tanh(Ct)
2.3 总结
以上,就是LSTM的内部结构。通过门控状态来控制传输状态,记住需要长时间记忆的,忘记不重要的信息;而不像普通的RNN那样只能够“呆萌”地仅有一种记忆叠加方式。对很多需要“长期记忆”的任务来说,尤其好用。
但也因为引入了很多内容,导致参数变多,也使得训练难度加大了很多。因此很多时候我们往往会使用效果和LSTM相当但参数更少的GRU
来构建大训练量的模型。
2.4 LSTM 如何避免梯度消失和梯度爆炸?
上面说了,RNN的梯度消失和爆炸主要是由这个 ∏ j = k + 1 t ∂ h j ∂ h j − 1 = ∏ j = k + 1 t t a n h ′ ⋅ W s \prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot W_{s} ∏j=k+1t∂hj−1∂hj=∏j=k+1ttanh′⋅Ws 引起的,对于LSTM同样也包含这样的一项,但是在LSTM中是这样的: ∏ j = k + 1 t ∂ h j ∂ h j − 1 = ∏ j = k + 1 t t a n h ′ ⋅ σ ( W f x t + b f ) ≈ 0 ∣ 1 \prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot \sigma(W_f x_t+b_f)\approx0|1 ∏j=k+1t∂hj−1∂hj=∏j=k+1ttanh′⋅σ(Wfxt+bf)≈0∣1。
很显然里面这个 t a n h ′ ⋅ σ ( W f x t + b f ) {tanh^{'}}\cdot \sigma(W_f x_t+b_f) tanh′⋅σ(Wfxt+bf)相乘的结果不可能发生梯度消失和爆炸。
三、GRU
GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。
GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。因为GRU实验的实验效果与LSTM相似,但是更易于计算。
3.1 GRU的结构
GRU的输入输出结构与普通的RNN是一样的。只不过内部结构是在LSTM的基础上优化了。内部结构图如下:
(1)首先介绍GRU的两个门,分别是重置的门控
r
t
r_t
rt(reset gate) 和更新门控
z
t
z_t
zt(update gate) ,计算方法和LSTM中门的计算方法一致:
r
t
=
σ
(
W
r
⋅
[
h
t
−
1
,
x
t
]
)
r_t=\sigma(W_r \cdot [h_{t-1},x_t])
rt=σ(Wr⋅[ht−1,xt])
z
t
=
σ
(
W
z
⋅
[
h
t
−
1
,
x
t
]
)
z_t=\sigma(W_z \cdot [h_{t-1},x_t])
zt=σ(Wz⋅[ht−1,xt])
(2)然后是计算候选隐藏层
h
~
t
\tilde h_t
h~t(candidate hidden layer) ,这个候选隐藏层和LSTM中的
C
t
C_t
Ct是类似,可以看成是当前时刻的新信息,其中
r
t
r_t
rt 用来控制需要保留多少之前的记忆,比如如果
r
t
r_t
rt 为0,那么
h
~
t
\tilde h_t
h~t 只包含当前词的信息:
h
~
t
=
t
a
n
h
(
W
⋅
[
r
t
∗
h
t
−
1
,
x
t
]
)
\tilde h_t=tanh(W \cdot [r_t*h_{t-1},x_t])
h~t=tanh(W⋅[rt∗ht−1,xt])
h
~
t
\tilde h_t
h~t 的计算按下面这样看更清晰一些,黄色的线是
r
t
∗
h
t
−
1
r_t*h_{t-1}
rt∗ht−1,蓝色的线是
[
r
t
∗
h
t
−
1
,
x
t
]
[r_t*h_{t-1},x_t]
[rt∗ht−1,xt],红色的线是
W
⋅
[
r
t
∗
h
t
−
1
,
x
t
]
W \cdot [r_t*h_{t-1},x_t]
W⋅[rt∗ht−1,xt],然后再经过一层tanh:
(3)最后
z
t
z_t
zt 控制需要从前一时刻的隐藏层
h
t
−
1
h_{t-1}
ht−1 中遗忘多少信息
,需要加入多少当前时刻的隐藏层信息
h
~
t
\tilde h_t
h~t,最后得到当前位置的隐藏层信息
h
t
h_t
ht , 需要注意这里与LSTM的区别是GRU中没有output gate:
h
t
=
z
t
∗
h
~
t
+
(
1
−
z
t
)
∗
h
t
−
1
h_t=z_t*\tilde h_t+(1-z_t)*h_{t-1}
ht=zt∗h~t+(1−zt)∗ht−1
参考:
【1】RNN
【2】RNN梯度消失和爆炸的原因
【3】人人都能看懂的LSTM
【4】[译] 理解 LSTM 网络