由于我没有读过原论文,该博客写的内容几乎来自于李宏毅老师的self-attention课程,链接在这里:《台大李宏毅21年机器学习课程 self-attention和transformer》。该博客用于梳理笔记,以便后面复习的时候使用。如果后面读了相关论文或者有了新的理解会进行更改补充。
1 self-attention的思想及框架
self-attention的核心思想是在预测结果的时候把注意力放在不同的特征上。适用于输入一堆向量的问题。如以下例子[1]:
self-attention的框架图如下:
通过这个框架图,我们知道输入的是一个序列(sequence),通过self-attention这个盒子后,输出相同数量的向量,这些向量是包含了上下文信息的。最后再将这些向量放入到一个全连接层(fully connected network)中,最终就可以得到结果。
当然,以我现在的知识水平来说,既然输入和输出都是序列信息,那么输入可以是RNN或其变体的输出,或者也可以将self-attention输出后的内容输入到RNN及其变体中,而非框架中写到的fully connected network。
2 self-attention工作流程
我们现在将注意力放到self-attention这个盒子来。根据上面的框架图我们可以知道其工作的输入是一组sequence,输出也是一组sequence。而其内部的结构是如下的:
光看这个框架图,会发现其实有点像一个全连接的神经网络,即
b
i
\mathbf{b}^i
bi 是由所有的
a
\mathbf{a}
a 产生的。但是如果只是一个全连接的神经网络也不会有现在这样的热度,所以其内部还有其他的内容。我们回到self-attention的核心思想上,其考虑的是针对不同特征的注意力,这个注意力我个人理解也就是某一种关联关系,刻画了两个向量之间的关联程度,所以我们现在的目标是如何衡量两个输入向量之间的关联性。我们以
a
1
\mathbf{a}^1
a1 输出
b
1
\mathbf{b}^1
b1 为例:
我们设
α
1
,
i
\alpha^{1, i}
α1,i 为
a
1
\mathbf{a}^1
a1 与
a
i
\mathbf{a}^i
ai 之间的关联性(这里
i
i
i 可以为1,即自己与自己的注意力)。接着设两个权重矩阵
W
q
\mathbf{W}^q
Wq 与
W
k
\mathbf{W}^k
Wk,用
W
q
⋅
a
1
\mathbf{W}^q · \mathbf{a}^1
Wq⋅a1 可以得到向量
q
\mathbf{q}
q,即Query
向量;用
W
k
⋅
a
i
\mathbf{W}^k · \mathbf{a}^i
Wk⋅ai 得到向量
k
\mathbf{k}
k,即Key
向量。接着对
q
\mathbf{q}
q 与
k
\mathbf{k}
k 求内积即可获得
α
1
,
i
\alpha^{1, i}
α1,i。其框架图如下:
上图左边这个框框就是内积求得
α
\alpha
α 的示意图。而右边的话则是另外一种求
α
\alpha
α 的方法,将
q
\mathbf{q}
q 与
k
\mathbf{k}
k 做加法,通过一个非线性变换,再经过一个矩阵进行线性变换,得到
α
\alpha
α。以上是
a
1
\mathbf{a}^1
a1 与某一个向量做的关联性计算,而self-attention是要
a
1
\mathbf{a}^1
a1 与每个输入向量都做关联性计算,所以其示意图如下:
针对
a
1
\mathbf{a}^1
a1 要求得两个向量
q
1
\mathbf{q}^1
q1 与
k
1
\mathbf{k}^1
k1,其余的向量只需要求得
q
i
,
i
≠
1
\mathbf{q}^i, i \neq 1
qi,i=1。接着对这些
α
1
,
i
\alpha^{1, i}
α1,i 求softmax得到
α
1
,
i
′
\alpha'_{1, i}
α1,i′。
最后我们再用向量
a
i
\mathbf{a}^i
ai 与权重矩阵
W
v
\mathbf{W}^v
Wv 相乘得到向量
v
i
\mathbf{v}^i
vi,即Value
向量。将向量
v
i
\mathbf{v}^i
vi 与
α
1
,
i
′
\alpha'_{1, i}
α1,i′ 相乘,再对所有的这些乘后的值求和,即可获得
b
1
b^1
b1,如下图:
从形式化表达的角度上来讲,以上所有内容的公式表达如下(所有的向量均为列向量):
Q = W q I \mathbf{Q}=\mathbf{W}^q \mathbf{I} Q=WqI
K = W k I \mathbf{K}=\mathbf{W}^k \mathbf{I} K=WkI
V
=
W
v
I
\mathbf{V}=\mathbf{W}^v \mathbf{I}
V=WvI
其中,
I
=
[
a
1
,
a
2
,
a
3
,
.
.
.
]
,
Q
=
[
q
1
,
q
2
,
q
3
,
.
.
.
]
,
K
=
[
k
1
,
k
2
,
k
3
,
.
.
.
]
,
V
=
[
v
1
,
v
2
,
v
3
,
.
.
.
]
\mathbf{I}=[ \mathbf{a}^1, \mathbf{a}^2, \mathbf{a}^3, ...], \mathbf{Q}=[ \mathbf{q}^1, \mathbf{q}^2, \mathbf{q}^3, ...], \mathbf{K}=[ \mathbf{k}^1, \mathbf{k}^2, \mathbf{k}^3, ...], \mathbf{V}=[ \mathbf{v}^1, \mathbf{v}^2, \mathbf{v}^3, ...]
I=[a1,a2,a3,...],Q=[q1,q2,q3,...],K=[k1,k2,k3,...],V=[v1,v2,v3,...]。
针对第一个input:
α
1
,
1
=
k
1
T
q
1
,
α
1
,
2
=
k
2
T
q
1
,
α
1
,
3
=
k
3
T
q
1
,
.
.
.
\alpha_{1, 1}={\mathbf{k}^1}^T\mathbf{q}^1, \alpha_{1, 2}={\mathbf{k}^2}^T\mathbf{q}^1, \alpha_{1, 3}={\mathbf{k}^3}^T\mathbf{q}^1, ...
α1,1=k1Tq1,α1,2=k2Tq1,α1,3=k3Tq1,...,用矩阵表达为:
α
1
=
K
T
q
1
\mathbf{\alpha}_1 = \mathbf{K}^T \mathbf{q}^1
α1=KTq1
因为 α i , j \alpha_{i, j} αi,j是个标量,所以 α 1 \mathbf{\alpha}_1 α1 是个列向量。对每个input做 α \alpha α 的计算,表达为:
A = K T Q \mathbf{A}=\mathbf{K}^T \mathbf{Q} A=KTQ
其中,
A
\mathbf{A}
A为:
A
=
[
α
1
,
1
α
2
,
1
α
3
,
1
⋯
α
1
,
2
α
2
,
2
α
3
,
2
⋯
α
1
,
3
α
2
,
3
α
3
,
3
⋯
⋮
⋮
⋱
⋮
]
\mathbf{A}= \begin{bmatrix} \alpha_{1, 1} & \alpha_{2, 1} & \alpha_{3, 1} & \cdots \\ \alpha_{1, 2} & \alpha_{2, 2} & \alpha_{3, 2} & \cdots \\ \alpha_{1, 3} & \alpha_{2, 3} & \alpha_{3, 3} & \cdots \\ \vdots & \vdots & \ddots & \vdots \end{bmatrix}
A=⎣⎢⎢⎢⎡α1,1α1,2α1,3⋮α2,1α2,2α2,3⋮α3,1α3,2α3,3⋱⋯⋯⋯⋮⎦⎥⎥⎥⎤
接着对
A
\mathbf{A}
A 做softmax得到
A
′
\mathbf{A}'
A′,其中
A
′
\mathbf{A}'
A′为:
A
′
=
[
α
1
,
1
′
α
2
,
1
′
α
3
,
1
′
⋯
α
1
,
2
′
α
2
,
2
′
α
3
,
2
′
⋯
α
1
,
3
′
α
2
,
3
′
α
3
,
3
′
⋯
⋮
⋮
⋱
⋮
]
\mathbf{A}'= \begin{bmatrix} \alpha_{1, 1}' & \alpha_{2, 1}' & \alpha_{3, 1}' & \cdots \\ \alpha_{1, 2}' & \alpha_{2, 2}' & \alpha_{3, 2}' & \cdots \\ \alpha_{1, 3}' & \alpha_{2, 3}' & \alpha_{3, 3}' & \cdots \\ \vdots & \vdots & \ddots & \vdots \end{bmatrix}
A′=⎣⎢⎢⎢⎡α1,1′α1,2′α1,3′⋮α2,1′α2,2′α2,3′⋮α3,1′α3,2′α3,3′⋱⋯⋯⋯⋮⎦⎥⎥⎥⎤
最后用 V \mathbf{V} V 与 A ′ \mathbf{A}' A′ 做矩阵乘法得到最终输出的矩阵 O \mathbf{O} O:
O = V A ′ \mathbf{O}=\mathbf{V} \mathbf{A}' O=VA′
其中 O = [ b 1 , b 2 , b 3 , . . . ] \mathbf{O}=[\mathbf{b}^1, \mathbf{b}^2, \mathbf{b}^3, ...] O=[b1,b2,b3,...]。因为 A i , j ′ = α i , j \mathbf{A}'_{i, j}=\alpha_{i,j} Ai,j′=αi,j为标量,以 b 1 \mathbf{b}^1 b1 为例,这一步的计算过程为:
b 1 = α 1 , 1 ′ v 1 + α 1 , 2 ′ v 2 + α 1 , 3 ′ v 3 + . . . = ∑ i = 1 n α 1 , i ′ v i \mathbf{b}^1 = \alpha_{1, 1}'\mathbf{v}^1 + \alpha_{1, 2}'\mathbf{v}^2 + \alpha_{1, 3}'\mathbf{v}^3 + ... = \sum_{i=1}^n\alpha_{1, i}'\mathbf{v}^i b1=α1,1′v1+α1,2′v2+α1,3′v3+...=i=1∑nα1,i′vi
整体的计算流程就如下图所示:
其中需要训练的参数只有三个,即
W
q
,
W
k
,
W
v
\mathbf{W}^q, \mathbf{W}^k, \mathbf{W}^v
Wq,Wk,Wv。
根据以上流程,我们就可以发现self-attention的优点:可以并行化计算
。对比类似的RNN及其变体,我们发现,RNN及其变体是需要上一时刻的输出作为下一时刻的输入,即上一时刻没计算出来无法计算下一时刻的内容。但是self-attention不需要考虑这些,因为输入的是整个sequence,就可以直接采用矩阵进行运算。
3 Multi-head self-attension
Multi-head self-attension(多头自注意力)是self-attention的进阶版,其与self-attention的区别就在于,对于每一个输入向量,会有多个
q
,
k
,
v
\mathbf{q}, \mathbf{k}, \mathbf{v}
q,k,v,而不同的
q
,
k
,
v
\mathbf{q}, \mathbf{k}, \mathbf{v}
q,k,v 有不同的
W
q
,
W
k
,
W
v
\mathbf{W}^q, \mathbf{W}^k, \mathbf{W}^v
Wq,Wk,Wv(就是说需要训练不同的
W
q
,
W
k
,
W
v
\mathbf{W}^q, \mathbf{W}^k, \mathbf{W}^v
Wq,Wk,Wv)。以两个头的自注意力来说,如下图:
最后,乘以一个矩阵
W
o
\mathbf{W}^o
Wo 做线性变换,得到
b
i
\mathbf{b}^i
bi:
4 几个tricks
而self-attention使用的过程中会有些问题。
第一个问题就是,self-attention处理的是序列问题,但是我们发现其输入序列里面并没有顺序(因为每个向量的Query都会与所有向量的Key做乘法,再softmax后与每个的Value相乘求和),所以
a
1
\mathbf{a}^1
a1与
a
2
,
a
3
,
.
.
.
\mathbf{a}^2, \mathbf{a}^3, ...
a2,a3,...并无区别。那么针对这个问题,解决方案就是positional encoding
(位置编码)。对每个位置的向量,求得其位置信息向量
e
i
\mathbf{e}^i
ei,加到每个输入向量
a
i
\mathbf{a}^i
ai中即可。
第二个问题就是,如果sequence过长,那么我们在训练的过程中,每个 a i \mathbf{a}^i ai 会乘上许多的东西,而导致运行速度减慢以及内存爆炸的问题。于是如果输入的sequence过长,那么就人为设置一个窗口来限制注意力的范围。直觉上讲,就是句子过长后可能前后就毫无联系了。
5 参考
[1] 李rumor. NLP中的Attention原理和源码解析[EB/OL]. (2021-05-01)[2021-08-18]. https://zhuanlan.zhihu.com/p/43493999
[2] 爱学习的凉饭爷. 台大李宏毅21年机器学习课程 self-attention和transformer[EB/OL]. (2021-03-30)[2021-08-18]. https://www.bilibili.com/video/BV1Xp4y1b7ih?p=1
[3] Chihk-Anchor. transformer 模型(self-attention自注意力)[EB/OL]. (2019-01-08)[2021-08-18]. https://blog.csdn.net/weixin_40871455/article/details/86084560