Sequence to Sequence Learning with Neural Networks,从RNN开始
Sequence to Sequence Learning with Neural Networks这篇文章是Google在2014年发表的,较早的使用了Seq2Seq结构的文章,实现了从输入序列映射到不等长的输出序列的学习,在机器翻译的任务中,取得了非常好的成绩。作者首先指出深度神经网络能够在困难的学习任务中达到卓越的性能,但并不适用于从序列映射到到未知长度序列,如机器翻译和语音识别。本文建立模型的主要思想是,使用多层LSTM(长短期记忆网络,Long Short-Term Memory)将输入序列映射成一个固定维度的向量,然后使用另外一个多层LSTM从该向量中解码出目标序列。也就是Sequence to Sequence,经常缩写为Seq2Seq。下面将从什么是RNN开始介绍:
什么是RNN
RNN(循环神经网络,Recurrent Neural Network)是更好地处理时序信息而设计的。它引入状态变量来存储过去的信息,并用其与当前的输入共同决定当前的输出。
时序信息,以语言模型为例,假设一段长度为T的文本中的词依次为
x
1
,
x
2
,
…
x
T
x_1,x_2,…x_T
x1,x2,…xT那么在离散的时间序列中,
x
t
(
1
≤
t
≤
T
)
x_t(1≤t≤T)
xt(1≤t≤T)可以看做时间步t的序列信息。
RNN基本模型如下,
每个时间步为一个单元,有两个输入,一个是当前时间步的序列信息
x
t
x_t
xt,另一个是由上一个时间步计算输出的状态信息
h
t
h_t
ht。在A中的操作一般为全连接和激活函数。
h
t
=
f
h
(
x
t
,
h
t
−
1
)
=
t
a
n
h
(
W
h
x
x
t
+
W
h
h
h
t
−
1
+
b
h
)
y
t
=
f
y
(
h
t
)
=
s
o
f
t
m
a
x
(
W
y
h
t
)
\begin{aligned} & h_t=f_h(x_t,h_{t-1})=tanh(W_{hx}x_t+W_{hh}h_{t-1} + bh) \\ & y_t=f_y(h_t)=softmax(W_yh_t) \end{aligned}
ht=fh(xt,ht−1)=tanh(Whxxt+Whhht−1+bh)yt=fy(ht)=softmax(Wyht)
但这种最简单的RNN结构有一定局限性,如果一个序列足够长,那它们很难把信息从较早的时间步传输到后面的时间步,如下图可以一定程度反映这种情况。所以提出了LSTM模型。
什么是LSTM模型
LSTM多了一个表示cell记忆的值。LSTM为克服短期记忆问题的解决方案是,它们引入称作“门”的内部机制,可以调节信息流。这些门结构可以学习序列中哪些数据是要保留的重要信息,哪些是要删除的。通过这样做,它可以沿着长链序列传递相关信息来执行预测。几乎所有基于RNN的先进结果都是通过LSTM和其变种GRU实现的。
结构图和公式如下:
c
~
<
t
>
=
t
a
n
h
(
W
c
[
a
<
t
−
1
>
,
x
<
t
>
]
+
b
c
)
Γ
u
=
σ
(
W
u
[
a
<
t
−
1
>
,
x
<
t
>
]
+
b
u
)
Γ
f
=
σ
(
W
f
[
a
<
t
−
1
>
,
x
<
t
>
]
+
b
f
)
Γ
o
=
σ
(
W
o
[
a
<
t
−
1
>
,
x
<
t
>
]
+
b
o
)
c
<
t
>
=
Γ
u
∗
c
~
<
t
>
+
Γ
f
∗
c
<
t
−
1
>
a
<
t
>
=
Γ
o
∗
t
a
n
h
(
c
<
t
>
)
\begin{aligned} &\widetilde c^{<t>}=tanh(W_c[a^{<t-1>},x^{<t>}]+b_c)\\ &\Gamma_u=\sigma(W_u[a^{<t-1>},x^{<t>}]+b_u)\\ &\Gamma_f=\sigma(W_f[a^{<t-1>},x^{<t>}]+b_f)\\ &\Gamma_o=\sigma(W_o[a^{<t-1>},x^{<t>}]+b_o)\\ &c^{<t>}=\Gamma_u*\widetilde c^{<t>}+\Gamma_f*c^{<t-1>}\\ &a^{<t>}=\Gamma_o*tanh(c^{<t>}) \end{aligned}
c
<t>=tanh(Wc[a<t−1>,x<t>]+bc)Γu=σ(Wu[a<t−1>,x<t>]+bu)Γf=σ(Wf[a<t−1>,x<t>]+bf)Γo=σ(Wo[a<t−1>,x<t>]+bo)c<t>=Γu∗c
<t>+Γf∗c<t−1>a<t>=Γo∗tanh(c<t>)
图中的a就是之前所说的h,表示隐藏状态信息。从图中可以看出,输入值和输出值的个数都从2变成了3,多了一个表示记忆的值C;单元内部多了forget gate遗忘门,update gate更新门(输入门)和output gate输出门。
遗忘门、更新门和输出门的计算方式相同,各自权重和
[
a
t
−
1
,
x
t
]
[a_{t-1},x_t]
[at−1,xt]相乘加上偏置,得到的值用sigmoid函数激活为一个0到1之间的表示比例的值。
从多出来的输入值C着手,就有了两个问题,多出来的输入值
C
C
C如何在每一个时间步中更新自己?
C
C
C是怎样起作用的?
- 现将 C t − 1 C_{t-1} Ct−1与 Γ f \Gamma_f Γf元素对应位置相乘,来把上一时间步传进来的 C C C的值“忘记”一部分。
- 用当前时间步的另外两个输入值 x t x_t xt和 a t − 1 a_{t-1} at−1得到 C C C的更新值 C ~ t \widetilde C_t C t
- 用更新门控制更新值的分量加到“忘记”过的 C t − 1 C_{t-1} Ct−1中,就得到了当前时间步的 C C C的值 C t C_t Ct
- 将 C t C_t Ct用双曲正切函数激活得到一个-1到1之间的值,再与输出门元素对应位置相乘,就得到当前时间步的状态值。一方面传递到下一个时间步进行计算,一方面用来计算当前时间步的输出值y
回到论文中
论文中模型结构的主要思想是,用多层LSTM将输入序列映射到定长变量,再由定长变量通过多层LSTM解码出序列,W左边的部分为编码器encoder,右边的部分为解码器decoder,解码器预测到句尾符号<EOS>(视为一个特殊的单词)时停止解码,这使得模型能够生成不定长度的序列。如下图所示,输入序列为ABC以及序列结束符号<EOS>,映射到定长变量v,生成输出序列WXYZ以及序列结束符号<EOS>。
解码器中将每个时间步的输出值作为下一个时间步的输入值进行预测,体现了一种条件概率的思想。
P
(
y
1
,
…
,
y
T
′
)
=
∏
t
=
1
T
′
P
(
y
t
∣
v
,
y
1
,
…
,
y
t
−
1
)
P(y_1,…,y_{T'})=\prod_{t=1}^{T'}P(y_t|v,y_1,…,y_{t-1})
P(y1,…,yT′)=t=1∏T′P(yt∣v,y1,…,yt−1)
P
(
y
1
,
y
2
,
…
,
y
T
′
)
=
P
(
y
1
∣
v
)
P
(
y
2
∣
v
,
y
1
)
…
P
(
y
T
′
∣
v
,
y
1
,
y
2
,
…
,
y
T
′
−
1
)
P(y_1,y_2,…,y_{T'})=P(y1|v)P(y2|v,y1)…P(y{T'}|v,y_1,y_2,…,y_{T'-1})
P(y1,y2,…,yT′)=P(y1∣v)P(y2∣v,y1)…P(yT′∣v,y1,y2,…,yT′−1)
在这个等式中,每个
P
(
y
t
∣
v
,
y
1
,
…
,
y
t
−
1
)
P(y_t|v,y_1,…,y_{t-1})
P(yt∣v,y1,…,yt−1)分布用词汇表中所有单词的softmax表示。
此外还有以下几个关键点:
- 模型中encoder和decoder使用了不同的LSTM模型
因为这样做可以增加模型参数的数量,但计算代价可忽略不计,并且很自然的可以在多语言对上训练LSTM。 - 输入序列时,将序列逆序输入到模型中取得了更好的效果
如图中,要将ABC翻译为WXYZ的话,输入序列应为CBA。论文坦诚表明没有给出完善的理论依据,给出的解释是,顺序与逆序输入并没有改变平均距离(这里距离指的是time step diff),但是却让源句子与翻译目标语句开头的几个词的距离变短了,也就是加强了源句子中第一个单词A与翻译目标语句中第一个单词W的联系,而句子末尾的词距离变长的代价似乎并不显著,因此逆序输入会得到更好的效果。更为精准的翻译语句开头的单词,提升了句子翻译水平。 - 深层次的LSTM模型比浅层次的模型要好,文章中使用了4层LSTM
- 在decoder中应用了beam search
假设词表大小为3,内容为a, b, c。 beam size为2
- 生成第1个词的时候,选择概率最大的两个词,假设为a, c,那么当前序列就是a, c
- 生成第2个词的时候,我们将当前序列a和c,分别与词表中的所有词进行组合,得到新的6个序列aa ab ac ca cb cc,然后从其中选择2个得分最高的,当作当前序列,假如为aa cb
- 后面会不断重复这个过程,直到遇到结束符为止。最终输出2个得分最高的序列。
参考资料
https://zhuanlan.zhihu.com/p/46981722
https://towardsdatascience.com/illustrated-guide-to-recurrent-neural-networks-79e5eb8049c9
https://www.cnblogs.com/zuotongbin/p/10698843.html