写在前面
本文主要整理了最简单的seq2seq learning,以及加入Attention思想后的seq2seq model,大多数内容来自台大李宏毅老师的课上内容,加了一些自己理解的有疑问的点。
涉及论文:Sequence to Sequence Learning with Neural Networks.
seq2seq来源
DNN可以解决很多问题,但是其最大的问题便是不能考虑时序关系,另外,DNN要求每个batch输入的维度是一致了,这点在机器翻译过程中很难提前预知,因此出现了seq2seq模型。
seq2seq本质
seq2seq的本质是encoder-decoder模型,下面以翻译任务中的“汉译英”为例,比如,模型首先使用编码器对汉语进行编码,得到语义向量 c c c,然后使用解码器对 c c c进行解码,得到对应的英文翻译结果。由于encoder与decoder两端处理的都是序列数据,所以被称为sequence-to-sequence,简称seq2seq。目前应用最多的编/解码器是RNN(LSTM,GRU)。
seq2seq model
最简单的seq2seq模型如下图所示:
上图左边是Encoder过程,以"我是谁->Who am I"的翻译为例,则图中A表示“我”的向量表示,B表示“是”的向量表示,C表示“谁”的向量表示,是结束标志。
通过Encoder过程,生成了语义向量W。
右边是Decoder过程,首先将W喂进模型,生成X;接着因为序列之间的关系是相互影响的,因此除了将W喂进去之外,也将前面生成的X喂进去,生成Y;接着将W和Y喂进去生成Z…
这便是最简单的seq2seq模型,绿色的黄色的方框可以使用RNN/LSTN/RGU等。
seq2seq存在的问题
1、上图可以看出,在Decode过程中喂进去的是一样的W,但是对于翻译过程来说,每个输入进去的词对于当前翻译结果的权重影响都是不同的,比如,现在翻译任务是“机器学习->Machine Learning”,则对于正在翻译“Learning”的这个过程来说,“学习”的重要性远远大于“机器”。
2、另外,W需要存储前面所有词语提供的信息,序列的长度是一个非常大的瓶颈,如机器翻译问题,当要翻译的句子较长时,一个W可能存不下那么多信息,就会造成翻译精度的下降。因此,输入数据有所侧重,对于翻译精度的提高非常重要,因此出现了Attention-based seq2seq Model。
Attention-based seq2seq Model
下面首先通过一个动态图讲一下attention的计算过程:
如上图所示,A,B,C是输入数据,h1、h2、h3和h4是RNN/LSTM等在Encode后的memory数据(这里也可以使用y,但是y只是h转换后的一个结果,所以这里使用h即可)。得到h1、h2、h3和h4之后,提供向量
z
0
z^0
z0,其数据是训练得出。
接下来进行如下计算(相当于模型去搜索对当前翻译最有用的输入的过程):
①将
z
0
z^0
z0和h1经过match函数得到
α
0
1
\alpha_0^1
α01(一个标量)
②将
z
0
z^0
z0和h2经过match函数得到
α
0
2
\alpha_0^2
α02
③将
z
0
z^0
z0和h3经过match函数得到
α
0
3
\alpha_0^3
α03
④将
z
0
z^0
z0和h4经过match函数得到
α
0
4
\alpha_0^4
α04。
在得到
α
0
1
\alpha_0^1
α01、
α
0
2
\alpha_0^2
α02、
α
0
3
\alpha_0^3
α03和
α
0
4
\alpha_0^4
α04之后,我们进行如下操作:
即将
α
0
1
\alpha_0^1
α01、
α
0
2
\alpha_0^2
α02、
α
0
3
\alpha_0^3
α03和
α
0
4
\alpha_0^4
α04经过一层softmax之后得到
α
^
0
1
\hat{\alpha}_0^1
α^01、
α
^
0
2
\hat{\alpha}_0^2
α^02、
α
^
0
3
\hat{\alpha}_0^3
α^03和
α
^
0
4
\hat{\alpha}_0^4
α^04,假设
α
^
0
1
=
α
^
0
2
=
0.5
\hat{\alpha}_0^1=\hat{\alpha}_0^2=0.5
α^01=α^02=0.5,
α
^
0
3
=
α
^
0
4
=
0.0
\hat{\alpha}_0^3=\hat{\alpha}_0^4=0.0
α^03=α^04=0.0,则
C
1
=
∑
α
^
0
i
h
i
=
0.5
∗
h
1
+
0.5
∗
h
2
+
0.0
∗
h
3
+
0.0
∗
h
4
C1=\sum{\hat{\alpha}_0^i}hi=0.5*h1+0.5*h2+0.0*h3+0.0*h4
C1=∑α^0ihi=0.5∗h1+0.5∗h2+0.0∗h3+0.0∗h4,这就是第一个代替上述W的输入数据。得到第一个输入结果,我们经过Decoder过程得到第一个输出结果,并不断往下进行,过程如下:
其中,
z
0
,
z
1
,
z
2
,
.
.
.
z^0,z^1,z^2,...
z0,z1,z2,...都是需要训练的数据。
以上就是attention机制的seq2seq。
match函数通常有哪些?
①coscin函数
②简单的DNN网络
③类似
y
=
h
T
W
z
y=h^TWz
y=hTWz的变换
模型使用RNN还是LSTM?
我们都知道RNN和LSTM都能用来处理序列数据,当然LSTM是RNN的变体,但是为什么很多情况下选LSTM呢?这是因为RNN针对长序列可能会引起梯度消失/梯度爆炸的问题,如下图(来自台大李宏毅老师):
下面是一个简单的RNN模型,第一个输入是1,其余都是0,使用的激活函数是线性函数,因此最后一个输出是
W
99
W^{99}
W99,数值的微小变化都会引起梯度消失/梯度爆炸的问题,而LSTM中有forget gate,可能在某个时刻就将前面的记忆清空,有效地预防了这么问题。因此,一般对于处理长句子来说,LSTM优于RNN。
大概就到这,以后有什么新的想法再来补充!