本文首发于公众号【DeepDriving】,欢迎关注。
0. 前言
2017
年,谷歌研究人员在《Attention Is All You Need
》这篇论文中提出了Transformer
模型,该模型最初是被用于机器翻译任务中。由于其良好的可并行性和强大的特征提取能力,Transformer
模型在随后的几年中被用到自然语言处理、语音识别、计算机视觉等各个领域中,并表现出优异的性能。
本文基于论文的内容解读Transformer
模型的各个组成部分,然后用Python
实现一个完整的Transformer
模型。
1. Transformer模型结构解析
1.1 模型总体架构
Transformer
的总体架构如下图所示,模型包含一个编码器和解码器(分别对应下图中的左侧和右侧部分),编码器和解码器都是由一系列堆叠的注意力结构和全连接层组成。
编码器
编码器由 N = 6 N=6 N=6个相同的层组成,每个层又包含两个子层:第一个子层为多头自注意力机制,第二个子层为一个简单的全连接前馈网络。这两个子层都采用了残差连接结构,后面接一个LayerNorm
层,也就是说,每个子层的输出为 L a y e r N o r m ( x + S u b L a y e r ( x ) ) LayerNorm(x+SubLayer(x)) LayerNorm(x+SubLayer(x))。因为使用了残差连接结构,模型中所有子层,包括输入的Embedding
层,它们的输出维度 d m o d e l d_{model} dmodel都等于512
。
解码器
解码器也是由 N = 6 N=6 N=6个相同的层组成,除了使用了与编码器相同的子层外,解码器还在其中插入了第三个子层,这个子层对编码器的输出memory
执行多头注意力机制。与编码器类似的,解码器的子层也采用残差连接结构,后面再接一个LayerNorm
层。需要注意的是,解码器在多头自注意力子层中添加了一个掩码,这种机制可以确保对位置 i i i的预测只能依赖于小于位置 i i i的已知输出。
解码器的输出通过可学习的线性变换层和SoftMax
函数转换为预测下一个Token
的概率。
1.2 模型结构详解
1.2.1 注意力机制
注意力函数可以描述为将查询(query
)和一组键(key
)- 值(value
)对映射到输出,其中query
、key
、和value
都是向量。注意力函数的功能就是计算value
的加权和,其中分配给每个value
的权重由与query
和key
相关的特定函数计算得出。
缩放点积注意力
作者提出的注意力称为缩放点积注意力,它的输入是维度为 d k d_{k} dk的query
和key
,以及维度为 d v d_{v} dv的value
。对于输入的query
,首先计算它与key
的点积并除以缩放系数 d k \sqrt{d_{k}} dk,然后用一个SoftMax
函数来计算应用到value
上的权重,这个权重再与value
做点积运算得到最终结果。
在实际应用中,会把一组query
向量打包在一起组成矩阵Q
,相应的key
和value
也分别打包为矩阵K
和V
,然后同时用注意力函数进行计算:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dkQKT)V
为什么Q
和K
的点积结果要除以系数 d k \sqrt{d_{k}} dk?因为作者发现如果 d k d_{k} dk的值比较大,那么Q
和K
点积的结果会产生很大的值,这样经过SoftMax
函数后会产生非常小的梯度而不利于模型训练。为了消除这种影响,作者把点积结果除以一个系数 d k \sqrt{d_{k}} dk,这也是为什么作者把这种注意力称为缩放注意力的原因。
多头注意力
把输入的query
、key
和value
用不同的、可学习的线性映射操作分别映射h
次,映射后的维度分别为 d k d_{k} dk、 d k d_{k} dk和 d v d_{v} dv,然后每个映射的版本再并行地进行注意力计算,产生 d v d_{v} dv维度的输出结果。把这h
个输出的结果拼接到一起然后再做一次映射,使得最后输出结果的维度与原始输入相同。作者把这种多次映射再分别进行注意力计算的结构称为多头注意力,它比只使用一个维度为 d m o d e l d_{model} dmodel的query
、key
和value
来计算注意力的效果要好很多。
与单头注意力结构相比,多头注意力使得模型具备关注来自不同表示子空间信息的能力,模型的学习能力更强大。多头注意力机制其实就是将输入序列进行多组自注意力处理的过程,可以用公式表示为:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O w h e r e h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W