Scaled Dot-Product Attention详解

背景

Scaled Dot-Product Attention 是一种注意力机制,由 Attention Is All You Need 一文中提出,其中点积会被 d k \sqrt{d_{k}} dk 缩放。具体来说,我们有一个查询向量 Q Q Q、一个键向量 K K K 和一个值向量 V V V,注意力的计算方式如下:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dk QKT)V

  • Query (Q): 可以将查询视为在特定时间步骤 t t t 的单词的表示。它类似于一个问题,用来检查与序列中其他单词的兼容性。
  • Key (K): 键是我们用来检查与查询兼容性的标记。它就像是查询所提出问题的答案。
  • Value (V): 值是标记的实际表示向量。它类似于与单词相关联的有意义的信息或内容。

Attention计算

1. 线性变换(Linear Transformations)

当我们将一个 e m b e d d i n g embedding embedding序列传递到Transformer的输入(包括encoder和decoder)时,首先发生的步骤是对每个 e m b e d d i n g embedding embedding进行三次独立的线性变换,从而生成三个向量—— q u e r y query query k e y key key v a l u e value value。这些变换通过将输入向量 e m b e d d i n g embedding embedding)与三个权重矩阵相乘来实现(权重矩阵 W Q W^Q WQ W K W^K WK W V W^V WV的参数是在模型训练过程中学习得到的)。
下图展示了一个序列长度为2,embedding大小为4的向量示例:

图片来源:The Illustrated Transformer by Jay Alammar

注意:此处仅为简化版的说明,真实情况下在对 X X X进行 e m b e d d i n g embedding embedding处理时还需要加上Positional Embeding等操作。

2. Scaled Dot-Product Attention计算

在获得 Q Q Q K K K V V V矩阵后,我们便可以进行Scaled Dot-Product Attention的计算了。
在这里插入图片描述

图片来源:The Illustrated Transformer by Jay Alammar

步骤

  1. 首先,取出一个单词的query vector,并与序列中每个单词(包括它自己)的key vector的转置进行dot product,得到的结果称为attention scoreattention weight(即, Q K T QK^{T} QKT)。
  2. 接着,将获得的每个attention score除以key vector维度的平方根( d k \sqrt{d_{k}} dk ),得到缩放后的注意力分数(scaled attention score)。

    为什么要通过 d k d_{k} dk 对点积进行缩放?
    通过 d k d_{k} dk 对点积进行缩放是为了应对在 d k d_{k} dk 取较大值时,点积的幅度变得很大这一问题。当点积的幅度变大时,softmax 函数会被推入梯度极小的区域,这会导致梯度消失问题,从而影响模型的训练效果。通过对点积进行缩放,可以将其幅度控制在一个合理的范围内,从而避免梯度过小的问题,确保模型能够有效地学习。

  3. 将缩放后的注意力分数通过 s o f t m a x softmax softmax函数处理,得到概率分布(注意力权重),确保所有值都在0到1之间。
  4. 最后,取出每个value vector,与 s o f t m a x softmax softmax函数的输出结果进行dot product,得到最终输出。

样例

为了便于理解,此处样例为简化版,序列长度仅为3,维度仅为4。
Q = [ 0.212 0.04 0.63 0.36 0.1 0.14 0.86 0.77 0.31 0.36 0.19 0.72 ] ,   K = [ 0.31 0.84 0.963 0.57 0.45 0.94 0.73 0.58 0.36 0.83 0.1 0.38 ] ,   V = [ 0.36 0.83 0.1 0.38 0.31 0.36 0.19 0.72 0.31 0.84 0.963 0.57 ] Q = \begin{bmatrix} 0.212 & 0.04 & 0.63 & 0.36\\ 0.1 & 0.14 & 0.86 & 0.77\\ 0.31 & 0.36 & 0.19 & 0.72 \end{bmatrix}, \ K = \begin{bmatrix} 0.31 & 0.84 & 0.963 & 0.57\\ 0.45 & 0.94 & 0.73 & 0.58\\ 0.36 & 0.83 & 0.1 & 0.38 \end{bmatrix}, \ V = \begin{bmatrix} 0.36 & 0.83 & 0.1 & 0.38\\ 0.31 & 0.36 & 0.19 & 0.72\\ 0.31 & 0.84 & 0.963 & 0.57 \end{bmatrix} Q= 0.2120.10.310.040.140.360.630.860.190.360.770.72 , K= 0.310.450.360.840.940.830.9630.730.10.570.580.38 , V= 0.360.310.310.830.360.840.10.190.9630.380.720.57

  1. Dot Product:
    Q K T = [ 0.212 0.04 0.63 0.36 0.1 0.14 0.86 0.77 0.31 0.36 0.19 0.72 ] [ 0.31 0.45 0.36 0.84 0.94 0.83 0.963 0.73 0.1 0.57 0.58 0.38 ] = [ 0.91121 0.8017 0.30932 1.41568 1.251 0.5308 0.99187 1.0342 0.703 ] QK^T=\begin{bmatrix} 0.212 & 0.04 & 0.63 & 0.36\\ 0.1 & 0.14 & 0.86 & 0.77\\ 0.31 & 0.36 & 0.19 & 0.72 \end{bmatrix} \begin{bmatrix} 0.31 & 0.45 & 0.36 \\ 0.84 & 0.94 & 0.83 \\ 0.963 & 0.73 & 0.1 \\ 0.57 & 0.58 & 0.38 \end{bmatrix} =\begin{bmatrix} 0.91121 & 0.8017 & 0.30932\\ 1.41568 & 1.251 & 0.5308\\ 0.99187 & 1.0342 & 0.703 \end{bmatrix} QKT= 0.2120.10.310.040.140.360.630.860.190.360.770.72 0.310.840.9630.570.450.940.730.580.360.830.10.38 = 0.911211.415680.991870.80171.2511.03420.309320.53080.703

  2. Scale ( d k = 4 d_{k}=4 dk=4):
    Q K T d k = [ 0.91121 0.8017 0.30932 1.41568 1.251 0.5308 0.99187 1.0342 0.703 ] × 1 4 = [ 0.455605 0.40085 0.15466 0.70784 0.6255 0.2654 0.495935 0.5171 0.3515 ] \frac{QK^T}{\sqrt{d_k}}=\begin{bmatrix} 0.91121 & 0.8017 & 0.30932\\ 1.41568 & 1.251 & 0.5308\\ 0.99187 & 1.0342 & 0.703 \end{bmatrix} \times \frac{1}{\sqrt{4}} =\begin{bmatrix} 0.455605 & 0.40085 & 0.15466\\ 0.70784 & 0.6255 & 0.2654\\ 0.495935 & 0.5171 & 0.3515 \end{bmatrix} dk QKT= 0.911211.415680.991870.80171.2511.03420.309320.53080.703 ×4 1= 0.4556050.707840.4959350.400850.62550.51710.154660.26540.3515

  3. Softmax
    s o f t m a x ( Q K T d k ) = e z i ∑ j = i K e z j = [ 0.372 0.352 0.275 0.39 0.359 0.251 0.346 0.354 0.3 ] softmax(\frac{QK^T}{\sqrt{d_k}})=\frac{e^{z_i}}{\sum^K_{j=i}e^{z_j}}=\begin{bmatrix} 0.372 & 0.352 & 0.275\\ 0.39 & 0.359 & 0.251\\ 0.346 & 0.354 & 0.3 \end{bmatrix} softmax(dk QKT)=j=iKezjezi= 0.3720.390.3460.3520.3590.3540.2750.2510.3

  4. Scaled Dot-Product Attention
    s o f t m a x ( Q K T d k ) × V = [ 0.372 0.352 0.275 0.39 0.359 0.251 0.346 0.354 0.3 ] [ 0.36 0.83 0.1 0.38 0.31 0.36 0.19 0.72 0.31 0.84 0.963 0.57 ] = [ 0.32829 0.66648 0.368905 0.55155 0.3295 0.66378 0.348923 0.54975 0.3273 0.66662 0.39076 0.55736 ] softmax(\frac{QK^T}{\sqrt{d_k}}) \times V=\begin{bmatrix} 0.372 & 0.352 & 0.275\\ 0.39 & 0.359 & 0.251\\ 0.346 & 0.354 & 0.3 \end{bmatrix} \begin{bmatrix} 0.36 & 0.83 & 0.1 & 0.38\\ 0.31 & 0.36 & 0.19 & 0.72\\ 0.31 & 0.84 & 0.963 & 0.57 \end{bmatrix}= \begin{bmatrix} 0.32829 & 0.66648 & 0.368905 & 0.55155\\ 0.3295 & 0.66378 & 0.348923 & 0.54975\\ 0.3273 & 0.66662 & 0.39076 & 0.55736 \end{bmatrix} softmax(dk QKT)×V= 0.3720.390.3460.3520.3590.3540.2750.2510.3 0.360.310.310.830.360.840.10.190.9630.380.720.57 = 0.328290.32950.32730.666480.663780.666620.3689050.3489230.390760.551550.549750.55736


大家如果感觉有帮助可以点赞👍+收藏⭐️,也可以在评论区一起分享讨论!


参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值