Scaled Dot-Product Attention详解
背景
Scaled Dot-Product Attention 是一种注意力机制,由 Attention Is All You Need 一文中提出,其中点积会被 d k \sqrt{d_{k}} dk 缩放。具体来说,我们有一个查询向量 Q Q Q、一个键向量 K K K 和一个值向量 V V V,注意力的计算方式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dkQKT)V
- Query (Q): 可以将查询视为在特定时间步骤 t t t 的单词的表示。它类似于一个问题,用来检查与序列中其他单词的兼容性。
- Key (K): 键是我们用来检查与查询兼容性的标记。它就像是查询所提出问题的答案。
- Value (V): 值是标记的实际表示向量。它类似于与单词相关联的有意义的信息或内容。
Attention计算
1. 线性变换(Linear Transformations)
当我们将一个
e
m
b
e
d
d
i
n
g
embedding
embedding序列传递到Transformer的输入(包括encoder和decoder)时,首先发生的步骤是对每个
e
m
b
e
d
d
i
n
g
embedding
embedding进行三次独立的线性变换,从而生成三个向量——
q
u
e
r
y
query
query、
k
e
y
key
key和
v
a
l
u
e
value
value。这些变换通过将输入向量(
e
m
b
e
d
d
i
n
g
embedding
embedding)与三个权重矩阵相乘来实现(权重矩阵
W
Q
W^Q
WQ、
W
K
W^K
WK、
W
V
W^V
WV的参数是在模型训练过程中学习得到的)。
下图展示了一个序列长度为2,embedding大小为4的向量示例:
图片来源:The Illustrated Transformer by Jay Alammar
注意:此处仅为简化版的说明,真实情况下在对 X X X进行 e m b e d d i n g embedding embedding处理时还需要加上Positional Embeding等操作。
2. Scaled Dot-Product Attention计算
在获得
Q
Q
Q、
K
K
K、
V
V
V矩阵后,我们便可以进行Scaled Dot-Product Attention的计算了。
图片来源:The Illustrated Transformer by Jay Alammar
步骤
- 首先,取出一个单词的
query vector
,并与序列中每个单词(包括它自己)的key vector
的转置进行dot product
,得到的结果称为attention score
或attention weight
(即, Q K T QK^{T} QKT)。 - 接着,将获得的每个
attention score
除以key vector
维度的平方根( d k \sqrt{d_{k}} dk),得到缩放后的注意力分数(scaled attention score)。为什么要通过 d k d_{k} dk 对点积进行缩放?
通过 d k d_{k} dk 对点积进行缩放是为了应对在 d k d_{k} dk 取较大值时,点积的幅度变得很大这一问题。当点积的幅度变大时,softmax 函数会被推入梯度极小的区域,这会导致梯度消失问题,从而影响模型的训练效果。通过对点积进行缩放,可以将其幅度控制在一个合理的范围内,从而避免梯度过小的问题,确保模型能够有效地学习。 - 将缩放后的注意力分数通过 s o f t m a x softmax softmax函数处理,得到概率分布(注意力权重),确保所有值都在0到1之间。
- 最后,取出每个
value vector
,与 s o f t m a x softmax softmax函数的输出结果进行dot product
,得到最终输出。
样例
为了便于理解,此处样例为简化版,序列长度仅为3,维度仅为4。
Q
=
[
0.212
0.04
0.63
0.36
0.1
0.14
0.86
0.77
0.31
0.36
0.19
0.72
]
,
K
=
[
0.31
0.84
0.963
0.57
0.45
0.94
0.73
0.58
0.36
0.83
0.1
0.38
]
,
V
=
[
0.36
0.83
0.1
0.38
0.31
0.36
0.19
0.72
0.31
0.84
0.963
0.57
]
Q = \begin{bmatrix} 0.212 & 0.04 & 0.63 & 0.36\\ 0.1 & 0.14 & 0.86 & 0.77\\ 0.31 & 0.36 & 0.19 & 0.72 \end{bmatrix}, \ K = \begin{bmatrix} 0.31 & 0.84 & 0.963 & 0.57\\ 0.45 & 0.94 & 0.73 & 0.58\\ 0.36 & 0.83 & 0.1 & 0.38 \end{bmatrix}, \ V = \begin{bmatrix} 0.36 & 0.83 & 0.1 & 0.38\\ 0.31 & 0.36 & 0.19 & 0.72\\ 0.31 & 0.84 & 0.963 & 0.57 \end{bmatrix}
Q=
0.2120.10.310.040.140.360.630.860.190.360.770.72
, K=
0.310.450.360.840.940.830.9630.730.10.570.580.38
, V=
0.360.310.310.830.360.840.10.190.9630.380.720.57
-
Dot Product:
Q K T = [ 0.212 0.04 0.63 0.36 0.1 0.14 0.86 0.77 0.31 0.36 0.19 0.72 ] [ 0.31 0.45 0.36 0.84 0.94 0.83 0.963 0.73 0.1 0.57 0.58 0.38 ] = [ 0.91121 0.8017 0.30932 1.41568 1.251 0.5308 0.99187 1.0342 0.703 ] QK^T=\begin{bmatrix} 0.212 & 0.04 & 0.63 & 0.36\\ 0.1 & 0.14 & 0.86 & 0.77\\ 0.31 & 0.36 & 0.19 & 0.72 \end{bmatrix} \begin{bmatrix} 0.31 & 0.45 & 0.36 \\ 0.84 & 0.94 & 0.83 \\ 0.963 & 0.73 & 0.1 \\ 0.57 & 0.58 & 0.38 \end{bmatrix} =\begin{bmatrix} 0.91121 & 0.8017 & 0.30932\\ 1.41568 & 1.251 & 0.5308\\ 0.99187 & 1.0342 & 0.703 \end{bmatrix} QKT= 0.2120.10.310.040.140.360.630.860.190.360.770.72 0.310.840.9630.570.450.940.730.580.360.830.10.38 = 0.911211.415680.991870.80171.2511.03420.309320.53080.703 -
Scale ( d k = 4 d_{k}=4 dk=4):
Q K T d k = [ 0.91121 0.8017 0.30932 1.41568 1.251 0.5308 0.99187 1.0342 0.703 ] × 1 4 = [ 0.455605 0.40085 0.15466 0.70784 0.6255 0.2654 0.495935 0.5171 0.3515 ] \frac{QK^T}{\sqrt{d_k}}=\begin{bmatrix} 0.91121 & 0.8017 & 0.30932\\ 1.41568 & 1.251 & 0.5308\\ 0.99187 & 1.0342 & 0.703 \end{bmatrix} \times \frac{1}{\sqrt{4}} =\begin{bmatrix} 0.455605 & 0.40085 & 0.15466\\ 0.70784 & 0.6255 & 0.2654\\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix} dkQKT= 0.911211.415680.991870.80171.2511.03420.309320.53080.703 ×41= 0.4556050.707840.4959350.400850.62550.51710.154660.26540.3515 -
Softmax
s o f t m a x ( Q K T d k ) = e z i ∑ j = i K e z j = [ 0.372 0.352 0.275 0.39 0.359 0.251 0.346 0.354 0.3 ] softmax(\frac{QK^T}{\sqrt{d_k}})=\frac{e^{z_i}}{\sum^K_{j=i}e^{z_j}}=\begin{bmatrix} 0.372 & 0.352 & 0.275\\ 0.39 & 0.359 & 0.251\\ 0.346 & 0.354 & 0.3 \end{bmatrix} softmax(dkQKT)=∑j=iKezjezi= 0.3720.390.3460.3520.3590.3540.2750.2510.3 -
Scaled Dot-Product Attention
s o f t m a x ( Q K T d k ) × V = [ 0.372 0.352 0.275 0.39 0.359 0.251 0.346 0.354 0.3 ] [ 0.36 0.83 0.1 0.38 0.31 0.36 0.19 0.72 0.31 0.84 0.963 0.57 ] = [ 0.32829 0.66648 0.368905 0.55155 0.3295 0.66378 0.348923 0.54975 0.3273 0.66662 0.39076 0.55736 ] softmax(\frac{QK^T}{\sqrt{d_k}}) \times V=\begin{bmatrix} 0.372 & 0.352 & 0.275\\ 0.39 & 0.359 & 0.251\\ 0.346 & 0.354 & 0.3 \end{bmatrix} \begin{bmatrix} 0.36 & 0.83 & 0.1 & 0.38\\ 0.31 & 0.36 & 0.19 & 0.72\\ 0.31 & 0.84 & 0.963 & 0.57 \end{bmatrix}= \begin{bmatrix} 0.32829 & 0.66648 & 0.368905 & 0.55155\\ 0.3295 & 0.66378 & 0.348923 & 0.54975\\ 0.3273 & 0.66662 & 0.39076 & 0.55736 \end{bmatrix} softmax(dkQKT)×V= 0.3720.390.3460.3520.3590.3540.2750.2510.3 0.360.310.310.830.360.840.10.190.9630.380.720.57 = 0.328290.32950.32730.666480.663780.666620.3689050.3489230.390760.551550.549750.55736
大家如果感觉有帮助可以点赞👍+收藏⭐️,也可以在评论区一起分享讨论!
参考
- Attention Is All You Need by Ashish Vaswani, et al
- What is the intuition behind the dot product attention? by Educative
- Understanding Attention In Transformers Models by Alvaro Henriquez
- The Illustrated Transformer by Jay Alammar
- Self-Attention: A step-by-step guide to calculating the context vector by Lovelyn David