1、Transformer简介
Transformer提出自谷歌2017年的论文(Attention is all your need)不同于之前使用RNN、LSTM、GRU、CNN来作为encoder和decoder,Transformer完全使用注意力机制的模型。之所以选择弃用循环神经网络,是因为循环神经网络的每一次运算,都需要上一时刻的隐藏态数据,导致计算不能并行,增加了计算时间,也浪费了计算机性能。
Transformer的基本模型结构如下图所示,本文会根据论文和源码对模型做出我认为最通俗的解释。
2、Attention注意力机制
Transformer的核心是注意力机制,Transformer几乎全部由注意力机制来实现,论文的名字Attention is all your need说的也正是这个意思。所以我们首先来了解什么是注意力机制
2.1、什么是注意力机制
注意力机制可以看成一个问答对,由问题到答案之间需要一个key,就好像有人问了你一个问题(query),“你好呀,今天天气怎么样?” 你给出的答案(answer)可能为 “今天天气不好” 。人在得到这个答案的时候,首先是有一个问题,根据问题联想到两个重要的词为今天、天气,给出的答案也是根据这两个词,这两个词就是key,根据key,我们会找到脑海中的天气情况(value),最后根据我们知道的天气情况,说出答案。
注意力机制的实现就是通过三个向量,Q、K、V,可以看做QK到输出的映射。注意力机制输出是V的加权求和,其中分配给每个值的权重是通过q与k的相应函数(compatibility function)的计算的, 对应上一段,就是我们根据Q、K的矩阵计算,得到一个权重,这个权重的作用是选出V中有用的信息,也就是根据“今天”“天气”来从脑子中的众多信息中选出哪些是有用的。2.1节中会进一步解释。
2.2、Transformer中的QKV计算方法
Transformer中,QKV矩阵的计算方式如下,其中,
d
k
d_k
dk表示K的维度:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
{\rm Attention}(Q,K,V)={\rm softmax}(\frac{QK^T}{\sqrt{d_k}})V
Attention(Q,K,V)=softmax(dkQKT)V
论文中给出了上述公式的计算流程图如下:
2.2.1、计算过程
上面讲的,所谓注意力机制的输出,可以根据上图做出解释。Transformer中的Q、K、V向量的计算,首先是Q、K做点积,Scale就是除以
d
k
\sqrt{d_k}
dk,Mask是可选择的,这里先不做解释。经过softmax后,就得到了各个值的权重,与V做点积,得到的就是输出。
【举个栗子】
- 假设Q\K\V的shape为 [ 7 , 3 ] [7,3] [7,3],其中7为句子长度,3各个向量的维度。
- Q K T QK^T QKT结果的shape为 [ 7 , 7 ] [7, 7] [7,7],这里做点积的过程中,每个字的向量和其他所有字的向量进行一次点积,得到的矩阵中,第i行向量表示的就是第i个字和其他所有字的向量的点积结果。
- 这个结果经过softmax,得到的 [ 7 , 7 ] [7,7] [7,7]的矩阵,就是Q与K的相应函数(compatibility function,这个不会翻译,应该叫啥形象一点),其中每个元素都表示概率,所谓加权求和,就是矩阵与V矩阵相乘时,得到的 [ 7 , 3 ] [7,3] [7,3]的矩阵,其中每个元素都是V与相应概率向量做点积的结果,得到的就是Transformer中注意力机制的输出。
2.2.2、为什么除以 d k \sqrt{d_k} dk
这里为什么除以
d
k
\sqrt{d_k}
dk,论文中给出了解释——除以
d
k
\sqrt{d_k}
dk是为了防止做softmax后出现梯度过小的情况发生,也就是防止梯度消失。softmax公式如下,除
d
k
\sqrt{d_k}
dk后的效果如图所示,原本
z
i
,
z
i
+
1
z_i,z_{i+1}
zi,zi+1这些梯度接近于0的点出现在了
z
i
∗
,
z
i
+
1
∗
z_i^*,z_{i+1}^*
zi∗,zi+1∗的位置上,变得有了梯度。
s
o
f
t
m
a
x
(
z
i
)
=
e
z
i
∑
j
=
0
N
e
z
j
\large {\rm softmax}(z_i)=\frac{e^{z_i}}{\sum_{j=0}^Ne^{z_j}}
softmax(zi)=∑j=0Nezjezi
至于为什么除以
d
k
\sqrt{d_k}
dk,网上的解释是,点积后的矩阵符合
N
(
0
,
d
k
)
N(0, d_k)
N(0,dk)的正态分布,除以
d
k
\sqrt{d_k}
dk可以符合
N
(
0
,
1
)
N(0, 1)
N(0,1)的正态分布,开始还不理解,写着写着突然明白了。首先,Q,K,V都是根据输入计算出来的,如果进行一次标准化,其计算结果就是符合
N
(
0
,
1
)
N(0,1)
N(0,1)的标准正态分布,那么两个标准正态分布的矩阵Q、K的单个元素计算结果,其分布应满足:
E
(
Q
i
j
K
m
n
)
=
E
(
Q
)
E
(
K
)
=
0
×
0
=
0
D
(
Q
i
j
K
m
n
)
=
E
(
(
Q
K
)
2
)
−
(
E
(
Q
K
)
)
2
=
E
(
Q
2
)
E
(
K
2
)
−
0
2
=
[
D
(
Q
)
−
(
E
(
Q
)
)
2
]
×
[
D
(
K
)
−
(
E
(
K
)
)
2
]
=
1
×
1
=
1
E(Q_{ij}K_{mn})=E(Q)E(K) = 0\times0=0\\ D(Q_{ij}K_{mn})=E((QK)^2)-(E(QK))^2=E(Q^2)E(K^2)-0^2\\ =[D(Q)-(E(Q))^2]\times[D(K)-(E(K))^2]\\=1\times1=1
E(QijKmn)=E(Q)E(K)=0×0=0D(QijKmn)=E((QK)2)−(E(QK))2=E(Q2)E(K2)−02=[D(Q)−(E(Q))2]×[D(K)−(E(K))2]=1×1=1
根据矩阵运算的规则,运算后的元素
b
i
j
=
∑
Q
[
i
,
0
:
d
q
]
K
[
0
:
d
k
,
j
]
b_{ij}=\sum{Q_{[i,0:d_q]}K_{[0:d_k,j]}}
bij=∑Q[i,0:dq]K[0:dk,j],所以
b
i
j
b_{ij}
bij的概率分布符合:
E
(
Q
K
)
=
∑
i
=
0
d
k
E
(
Q
i
j
K
m
n
)
=
0
D
(
Q
K
)
=
∑
i
=
0
d
k
D
(
Q
i
j
K
m
n
)
=
1
×
d
k
=
d
k
E(QK)=\sum_{i=0}^{d_k}E(Q_{ij}K_{mn})=0\\ D(QK)=\sum_{i=0}^{d_k}D(Q_{ij}K_{mn})=1\times d_k=d_k
E(QK)=i=0∑dkE(QijKmn)=0D(QK)=i=0∑dkD(QijKmn)=1×dk=dk
因为方差为
d
k
d_k
dk,所以除以
d
k
\sqrt{d_k}
dk后重新符合标准正态分布。
2.3、self-attention
transformer的基本网络结构就是其提出的self-attention。
2.3.1、什么是self-attention
先说说为什么叫self-attention,attention机制中,query来自于source_sequence,而key和value来自于knowledge_sequence,也就是Q来自于一个输入,K和V来自于另一个输入。而self-attention中,QKV均来自于同一输入,这就是self的含义。源码链接, 第4423行
源码如下,这里的memory_antecedent就是K,V的输入,如果其值为None,则令其等于query_antecedent,也就是Q的输入,在transformer源码中,传输的memory_antecedent为None,并且会进行一次判定,如果不为None,会raise ValueError,判定的代码在上面链接的第4620行。
if memory_antecedent is None:
memory_antecedent = query_antecedent
q = compute_attention_component(
query_antecedent,
total_key_depth,
q_filter_width,
q_padding,
"q",
vars_3d_num_heads=vars_3d_num_heads,
layer_collection=layer_collection)
k = compute_attention_component(
memory_antecedent,
total_key_depth,
kv_filter_width,
kv_padding,
"k",
vars_3d_num_heads=vars_3d_num_heads,
layer_collection=layer_collection)
v = compute_attention_component(
memory_antecedent,
total_value_depth,
kv_filter_width,
kv_padding,
"v",
vars_3d_num_heads=vars_3d_num_heads,
layer_collection=layer_collection)
2.3.2、为什么提出self-attention
论文里从三个方面解释了为什么要使用self-attention,也就是从如下三个方面比较了self-attention的优势。
- 每层计算的复杂度
- 在最小的顺序操作数量下的并行计算量
- 网络中长距离依赖之间的路径长度
2.3.2.1、self-attention的复杂度
所谓每层计算的复杂度,指的是每层网络计算的时间复杂度。前面已经介绍过self-attention的计算方法,包括两次矩阵运算,一次缩放运算(除以 d k \sqrt{d_k} dk),一次softmax。这里对self-attention各部分的时间复杂度进行分析。
【矩阵点乘的复杂度】
假设矩阵A的维度为
[
m
×
n
]
[m\times n]
[m×n],矩阵B的维度为
[
n
×
l
]
[n\times l]
[n×l],矩阵点乘的过程为矩阵A的第i行的所有元素与矩阵B的第j列的所有元素乘积的和,作为结果矩阵中index为
[
i
,
j
]
[i,j]
[i,j]的元素,假设结果矩阵为C,C的维度为
[
m
×
l
]
[m\times l]
[m×l],
C
i
,
j
C_{i,j}
Ci,j的计算方法如下:
C
i
,
j
=
∑
k
=
0
n
−
1
A
i
,
k
×
B
k
,
j
C_{i,j} = \sum_{k=0}^{n-1}A_{i,k}\times B_{k,j}
Ci,j=k=0∑n−1Ai,k×Bk,j
上述运算中,包含n次乘法运算和n-1次加法运算,计算次数为
2
n
−
1
2n-1
2n−1。相同的计算需要进行
m
×
l
m\times l
m×l次,所以矩阵点乘的时间复杂度为:
(
2
n
−
1
)
×
m
×
l
=
O
(
n
m
l
)
(2n-1)\times m\times l=O(nml)
(2n−1)×m×l=O(nml)
因为Q、K、V矩阵的维度均为
[
n
×
d
]
[n\times d]
[n×d],n为句子长度,d为向量的维度,所以有:
- Q K T QK^T QKT的时间复杂度为 O ( n 2 d ) O(n^2d) O(n2d),结果矩阵的维度为 [ n × n ] [n\times n] [n×n]
- 第二次矩阵运算的时间复杂度为 O ( n 2 d ) O(n^2d) O(n2d),结果矩阵的维度为 [ n × d ] [n\times d] [n×d]
综上所述,两次矩阵运算的时间复杂度为 O ( n 2 d ) O(n^2d) O(n2d)
【缩放运算的复杂度】
缩放运算的运算就是对矩阵中的每个元素除以 d k \sqrt{d_k} dk,缩放运算时,矩阵中元素个数为 n × n n\times n n×n,所以缩放运算的时间复杂度为 O ( n 2 ) O(n^2) O(n2)
【softmax的时间复杂度】
softmax的计算过程如下:
p
=
e
z
j
∑
i
=
0
n
−
1
e
z
i
p = \frac{e^{z_j}}{\sum_{i=0}^{n-1}e^{z_i}}
p=∑i=0n−1eziezj
self-attention中,softmax操作的矩阵维度为
[
n
×
n
]
[n\times n]
[n×n],操作针对的是最后一维,也就是矩阵中的每一行。对每一行的操作,包括如下两个步骤:
- ∑ i = 0 n − 1 e z i \sum_{i=0}^{n-1}e^{z_i} ∑i=0n−1ezi,计算包括求 e z i e^{z_i} ezi和求和,运算次数为 n + n − 1 = 2 n − 1 n + n-1=2n-1 n+n−1=2n−1
- e z j ∑ i = 0 n − 1 e z i \large\large \frac{e^{z_j}}{\sum_{i=0}^{n-1}e^{z_i}} ∑i=0n−1eziezj,因为分母已经求得,所以只包含一次计算。
所以每一行的计算次数为
2
n
−
1
+
n
=
3
n
−
1
2n-1+n=3n-1
2n−1+n=3n−1。
因为矩阵有n行,所以softmax的计算次数为
n
×
(
3
n
−
1
)
=
O
(
n
2
)
n\times (3n-1)=O(n^2)
n×(3n−1)=O(n2)
综上所述,self-attention每层的时间复杂度为 O ( n 2 d ) O(n^2d) O(n2d)。其实后来有论文针对这个时间复杂度做过简化,前几天笔者还看到一篇谷歌的论文,用泰勒展开简化计算过程,下篇博客应该会讲这片论文,毕竟比较新,也不能总是写经典论文的笔记。
2.3.2.2、单层循环神经网络及单层卷积的时间复杂度
循环神经网络以RNN为例