本文是transformer的学习笔记,主要参考的是李宏毅老师的课程,本文的部分截图也是李老师视频课程上的截图,视频课地址:【台大李宏毅21年机器学习课程 self-attention和transformer】 https://www.bilibili.com/video/BV1Xp4y1b7ih/?p=3&share_source=copy_web&vd_source=80b47af2047bee820c0a722c39d8af69
首先是transformer的架构功能的学习。
这张图是从transformer提出论文中截取的,原文名称《Attention is All you Need》。
在对transformer进行介绍之前,首先介绍self—attention机制,因为transformer中主要使用了这个东西。
下图是self—attention的计算使用。
接下来介绍self—attention这个block的主要计算原理。
其中主要涉及到三个矩阵q,k,v。
首先,将输入的每个向量和其对应的k相乘,得到每个向量对应的key,而后首先将第一个向量与其对应的q矩阵相乘,得到其query,而后通过将第一个向量的query与后面所有向量的key相乘,得到第一个向量和后面每个向量之间的关系,也就是后面每个向量对于第一个向量的重要性,他的关联程度,这个结果需要通过softmax层(右上角的公式是softmax的计算公式,exp表示自然对数为底的指数),从而将其归一化,当然也可以使用relu或者其他激活函数,具体问题具体分析,下图是计算示例。
至此,算出了表示每个输入向量和第一个向量的关联程度的系数alpha。
接下来,通过v矩阵,计算得到每个输入向量的value,而后将刚刚得到的关联程度alpha与value相乘,以考虑每个输入向量对第一个向量的影响,最后再累加就得到了第一个输出b1,计算公式如右上所示。
依次类推,重复上述步骤得到b2的输出,从而得到所有的输出,示例如下。
以上是从细化的角度,逐个计算的,其实可以通过矩阵运算进行并行运算从而提高效率,一次得到所有的结果,见下图。
首先将所有的输入拼成一个矩阵,分别与三个权重矩阵相乘,这三个权重矩阵Wq、Wk、Wv是需要通过学习得到的,计算后得到Q、K、V矩阵。
然后通过将K矩阵转置与Q矩阵相乘得到A矩阵,通过softmax归一化之后得到关注度矩阵A’。
将V矩阵与得到的关注度矩阵A’相乘,即可得到最终的输出矩阵O。
对于上述步骤的总结如下图所示。
至此,self—attention的计算原理就全部介绍完了,transformer中用到的为Multi—Head attention,这是self—attention的进阶版,接下来对此进行介绍。
对于多头,其实就是有多个Wq,Wk,Wv矩阵,从而得到多个输出,而后将所有输出拼接在一起,利用一个Linear变成一个输出,从而通过多头机制注意学习到不同方面的特征。
至此,self—attention的计算原理,全部介绍完成。
现在开始对每个部分逐一介绍,transformer主要是由Encoder和Decoder两个部分组成,首先是Encoder部分,下图是示意图,从课程中截取。
首先对Encoder输入数据,对于时间序列数据,需要考虑时间序列信息,然而transformer没有想LSTM和RNN那种的循环机制,因此需要通过Positional Encoding的方式进行位置编码,以此来考虑时间序列信息,详细的计算过程请参考其他文章,此处不过多赘述。
在完成Positional Encoding之后,输入到Multi—Head Attention之中计算,对于输出的结果加上前面经过Positional Encoding的值,这一步是参考了残差网络,如此可以保证在网络层数较深的时候避免梯度消失问题,与此同时,进行Layer Norm,而不是Batch Norm,对同一个神经元的不同feature进行归一化操作,因为对同一batch的不同feature进行归一化没有意义。
而后通过一个Feed forward层(可以是Linear)前向传播,在经过一次Add&Norm之后,即完成了本次Encoder的计算。
至此,Encorder部分全部介绍完成,接下来介绍Decorder部分。
下图展示的是Decorder的输入输出方式,首先输入一个BEGIN的开始信号,然后输出第一个机,而后将第一个输出作为输入,得到下一个输出,依次类推直到结束。
主要的计算步骤和Encoder的十分相似,要注意mask的问题,且中间多了一个cross—attention,用来接收从Encoder传递过来的输入,并且和上一步的输入结合到一块,具体的计算步骤下面将详细介绍。
其主要是将Encoder中的K、V矩阵输入到cross—attention中从而完成二者的互动,计算过程示意如下。
至此,计算过程的原理全部写完。