参考链接:霹雳吧啦博客
主要参考了b站霹雳吧啦的视频,以及对应的帖子。记录Self-Attention
以及Multi-Head Attention
的理论。
前言
首先附上霹雳吧啦绘制的Self-Attention图示。
Self-Attention
在下图中,输入的序列长度为2,输入分别为 x 1 x_1 x1, x 2 x_2 x2。 x 1 x_1 x1, x 2 x_2 x2通过Input Embedding即 f ( x ) f\left(x\right) f(x)将输入映射到 a 1 a_1 a1, a 2 a_2 a2。然后将两个输入 a 1 a_1 a1, a 2 a_2 a2分别通过3个变换矩阵 W q W_q Wq, W k W_k Wk, W v W_v Wv得到对应的 q i q^i qi, k i k^i ki, v i v^i vi。
具体来说:
- q q q代表query,后续会与每一个键 k k k进行匹配。
- k k k代表key,后续会被每个 q q q匹配。
- v v v代表从高维映射后的 a a a中提取得到的信息。
- 后续 q q q和 k k k匹配的过程可以理解成计算两者的相关性,相关性越大对应 v v v的权重也就越大。
首先,Self-Attention的计算方式为:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention} (Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
Attention(Q,K,V)=softmax(dkQKT)V
1、计算
Q
Q
Q,
K
K
K,
V
V
V:
首先将原始输入分别经过Input Embedding即
f
(
x
)
f\left(x\right)
f(x)将输入映射到高维的
a
1
a_1
a1,
a
2
a_2
a2:
x 1 , x 2 → f ( x ) → a 1 , a 2 x_1, x_2 \rightarrow f(x) \rightarrow a_1, a_2 x1,x2→f(x)→a1,a2
接着利用3个变换矩阵
W
q
W_q
Wq,
W
k
W_k
Wk,
W
v
W_v
Wv(nn.Linea())得到
a
i
a_i
ai对应的
q
i
q^i
qi,
k
i
k^i
ki,
v
i
v^i
vi(过程如上图右下角所示):
若
a
1
=
(
1
,
1
)
a_1 = \text(1,1)
a1=(1,1),
a
2
=
(
1
,
0
)
a_2 = \text(1,0)
a2=(1,0),
W
q
=
(
1
1
0
1
)
W_q = \left( \begin{smallmatrix} 1 & 1 \\ 0 & 1 \end{smallmatrix} \right)
Wq=(1011),那么:
q
1
=
(
1
,
1
)
(
1
,
1
0
,
1
)
=
a
1
⋅
W
q
=
(
1
,
2
)
,
q
2
=
(
1
,
1
)
(
1
,
1
0
,
1
)
=
a
2
⋅
W
q
=
(
1
,
1
)
q^1 = \text(1,1) \binom{1,1}{0,1} = a_1 · W^q = \text(1,2), \ q^2 = \text(1,1) \binom{1,1}{0,1} = a_2 · W^q = \text(1,1)
q1=(1,1)(0,11,1)=a1⋅Wq=(1,2), q2=(1,1)(0,11,1)=a2⋅Wq=(1,1)
当Transformer进行并行化处理时,则可总结为:
(
q
1
q
2
)
=
(
1
,
1
1
,
0
)
(
1
,
1
0
,
1
)
=
(
1
,
2
1
,
1
)
\binom{q^1}{q^2} = \binom{1,1}{1,0} \binom{1,1}{0,1} = \binom{1,2}{1,1}
(q2q1)=(1,01,1)(0,11,1)=(1,11,2)
同理可以得到
(
k
1
k
2
)
\left( \begin{smallmatrix} k^1 \\ k^2 \end{smallmatrix} \right)
(k1k2)和
(
v
1
v
2
)
\left( \begin{smallmatrix} v^1 \\ v^2 \end{smallmatrix} \right)
(v1v2),则
(
q
1
q
2
)
\left( \begin{smallmatrix} q^1 \\ q^2 \end{smallmatrix} \right)
(q1q2)即为
Q
Q
Q,
(
k
1
k
2
)
\left( \begin{smallmatrix} k^1 \\ k^2 \end{smallmatrix} \right)
(k1k2)为
K
K
K,
(
v
1
v
2
)
\left( \begin{smallmatrix} v^1 \\ v^2 \end{smallmatrix} \right)
(v1v2)为
V
V
V。接着使用
q
1
q^1
q1与
{
k
1
,
k
2
,
.
.
.
,
k
n
}
\{k^1,k^2,... ,k^n\}
{k1,k2,...,kn}中的每个
k
k
k进行匹配(点乘操作),并除以
d
\sqrt{d}
d得到
α
α
α,其中
d
d
d为向量
k
i
k^i
ki的长度。通过除以
d
\sqrt{d}
d来缩放点乘后的数值,防止后续经过softmax时梯度过小```。计算
α
1
,
i
α_{1,i}
α1,i:
α
1
,
1
=
q
1
⋅
k
1
d
=
1
×
1
+
2
×
0
2
=
0.71
α
1
,
2
=
q
1
⋅
k
2
d
=
1
×
0
+
2
×
1
2
=
1.41
α_{1,1} = \frac{q^1 · k^1}{\sqrt{d}} = \frac{1×1+2×0}{\sqrt{2}} = 0.71 \\ α_{1,2} = \frac{q^1 · k^2}{\sqrt{d}} = \frac{1×0+2×1}{\sqrt{2}} = 1.41
α1,1=dq1⋅k1=21×1+2×0=0.71α1,2=dq1⋅k2=21×0+2×1=1.41
同理拿
q
2
q^2
q2去匹配所有的
k
k
k能得到
α
2
,
i
α_{2,i}
α2,i,统一写为矩阵乘法形式:
(
α
1
,
1
,
α
1
,
2
α
2
,
1
,
α
2
,
2
)
=
(
q
1
q
2
)
(
k
1
k
2
)
T
d
\binom{α_{1,1},α_{1,2}}{α_{2,1},α_{2,2}} = \frac{\left(\begin{smallmatrix} q^1 \\ q^2 \end{smallmatrix} \right) \left(\begin{smallmatrix} k^1 \\ k^2\end{smallmatrix} \right)^T}{\sqrt{d}}
(α2,1,α2,2α1,1,α1,2)=d(q1q2)(k1k2)T
然后对每个高维输入
a
i
a_i
ai对应的每一行,即
(
α
1
,
1
,
α
1
,
2
)
(α_{1,1}, α_{1,2})
(α1,1,α1,2)和
(
α
2
,
1
,
α
2
,
2
)
(α_{2,1}, α_{2,2})
(α2,1,α2,2)分别进行softmax处理得到
(
α
^
1
,
1
,
α
^
1
,
2
)
(\hatα_{1,1}, \hatα_{1,2})
(α^1,1,α^1,2)和
(
α
^
2
,
1
,
α
^
2
,
2
)
(\hatα_{2,1}, \hatα_{2,2})
(α^2,1,α^2,2),这里的
α
^
\hat{α}
α^相当于计算得到对于每个
v
v
v的权重(相当于给
v
v
v添加的注意力权重)。此时已经计算得到
softmax
(
Q
K
T
d
k
)
\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)
softmax(dkQKT)部分。
最后还需要将计算得到的权重
α
^
\hatα
α^乘以对应的
v
v
v,进行加权后得到的值为
b
1
=
α
^
1
,
1
×
v
1
+
α
^
1
,
2
×
v
2
=
(
0.33
,
0.67
)
b
2
=
α
^
2
,
1
×
v
1
+
α
^
2
,
2
×
v
2
=
(
0.50
,
0.50
)
b_1 = \hatα_{1,1} × v^1 + \hatα_{1,2} × v^2 = (0.33,0.67) \\ b_2 = \hatα_{2,1} × v^1 + \hatα_{2,2} × v^2 = (0.50,0.50)
b1=α^1,1×v1+α^1,2×v2=(0.33,0.67)b2=α^2,1×v1+α^2,2×v2=(0.50,0.50)
Multi-Head Attention
Multi-Head Attention基于Self-Attention模块进行搭建,首先和Self-Attention模块相同将 α i α_i αi分别通过 W q W_q Wq, W k W_k Wk, W v W_v Wv得到对应的 q i q^i qi, k i k^i ki, v i v^i vi,但不同的是此时需要根据多头的个数将 q i q^i qi, k i k^i ki, v i v^i vi拆分为 h h h份。例如当有两个注意力头时( h = 2 h=2 h=2),则 q 1 q^1 q1会被顺序拆分为 q 1 , 1 q^{1,1} q1,1和 q 1 , 2 q^{1,2} q1,2,对应的 q 1 , 1 q^{1,1} q1,1代表 h e a d 1 head1 head1, q 1 , 2 q^{1,2} q1,2代表 h e a d 2 head2 head2。注意这里生成的 q 1 , 1 q^{1,1} q1,1和 q 1 , 2 q^{1,2} q1,2直接由 q 1 q^1 q1按照 h e a d _ s i z e head\_size head_size顺序拆分即可。
然后将所有的 q i , 1 q^{i,1} qi,1, k i , 1 k^{i,1} ki,1, v i , 1 v^{i,1} vi,1分别组合为 h e a d i head_i headi对应的 Q i Q_i Qi, K i K_i Ki, V i V_i Vi,例如 { q 1 , 1 , q 2 , 1 , . . . , q n , 1 } \{q^{1,1},q^{2,1},...,q^{n,1}\} {q1,1,q2,1,...,qn,1}组合为 h e a d 1 head_1 head1对应的 Q Q Q。
然后针对每个
h
e
a
d
head
head使用和Self-Attention中相同的方法即可得到对应的结果。
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention} (Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
Attention(Q,K,V)=softmax(dkQKT)V
然后将 h e a d head head得到的结果进行concat拼接。例如将 h e a d 1 head_1 head1中的 b 1 , 1 b_{1,1} b1,1(对应 α 1 α_1 α1)和 b 2 , 1 b_{2,1} b2,1(对应 α 2 α_2 α2)拼接在一起, h e a d 2 head_2 head2中的 b 1 , 2 b_{1,2} b1,2(对应 α 1 α_1 α1)和 b 2 , 2 b_{2,2} b2,2(对应 α 2 α_2 α2)拼接在一起。
接着将拼接后的结果通过映射 W O W^O WO(可学习参数)进行融合,如下图所示,融合后得到最终的结果 b 1 , b 2 b_1,b_2 b1,b2。
Multi-Head Attention
的内容总结下来就是论文中的两个公式:
M
u
l
t
i
H
e
a
d
(
Q
,
K
,
V
)
=
C
o
n
c
a
t
(
h
e
a
d
1
,
.
.
.
,
h
e
a
d
h
)
W
O
w
h
e
r
e
h
e
a
d
i
=
A
t
t
e
n
t
i
o
n
(
Q
W
i
Q
,
K
W
i
K
,
V
W
i
V
)
\text MultiHead(Q,K,V) = \text Concat(head_1,...,head_h)W^O \\ \text where head_i = \text Attention(QW_i^Q,KW_i^K,VW_i^V)
MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhereheadi=Attention(QWiQ,KWiK,VWiV)