1. 前言
缩放点积注意力机制(scaled dot-product attention)是OpenAI的GPT系列大语言模型所使用的多头注意力机制(multi-head attention)的核心,其目标与前文所述简单自注意力机制完全相同,即输入向量序列 x 1 , x 2 , ⋯ , x n x_1, x_2, \cdots, x_n x1,x2,⋯,xn,计算context向量 z 1 , z 2 , ⋯ , z n z_1, z_2, \cdots, z_n z1,z2,⋯,zn。
缩放点积注意力机制计算context向量 z i z_i zi的流程与简单自注意力机制并不完全相同,在计算注意力分数及注意力权重之前,缩放点积注意力机制会使用3个参数矩阵 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv将输入向量 x i x_i xi分别映射成query向量 q i q_i qi,key向量 k i k_i ki以及value向量 v i v_i vi。缩放点积注意力机制使用query向量及key向量计算注意力分数,并用注意力权重对value向量加权求和计算context向量。
本文介绍缩放点积注意力机制生成context向量的计算流程,并实现继承自torch.nn.Module
的神经网络模块ScaledDotProductAttention
。
2. 缩放点积注意力机制
如下图所示,缩放点积注意力机制生成context向量的计算步骤如下:
- 计算qkv向量:使用输入向量 x i x_i xi分别点乘3个参数矩阵 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv,得到query向量 q i q_i qi,key向量 k i k_i ki以及value向量 v i v_i vi;
- 计算注意力分数:使用query及key向量的点积作为注意力分数;
- 计算注意力权重:将注意力分数 ω i \omega_i ωi除以key向量维度的平方根得到经过缩放的注意力分数 ω i s c a l e d \omega_i^{scaled} ωiscaled,然后将经过缩放的注意力分数归一化得到注意力权重;
- 计算context向量:使用注意力权重对value向量加权求和计算context向量。
2.1 计算qkv向量
如下图所示,缩放点积注意力机制计算输入向量 x 2 x_2 x2对应的context向量 z 2 z_2 z2的第一步是计算 x 2 x_2 x2对应的query向量 q 2 q_2 q2,以及所有输入向量 x i x_i xi对应的key及value向量 k i k_i ki和 v i v_i vi。随机初始化3个参数矩阵 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv,使用 x i x_i xi分别点乘 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv,可以得到 q i , k i , v i q_i, k_i, v_i qi,ki,vi。
参数矩阵 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv是神经网络的参数,随机初始化后在模型训练时更新。
缩放点积注意力机制中query,key和value借鉴了数据库领域的概念。
query与数据库检索时的查询信息类似,表示模型当前关注或试图理解的对象(即输入序列中的某个token)。
key类似于索引数据库记录的键,输入序列的每个token都有一个关联的key,用key与query的点积大小确定对输入序列中各个token的关注程度。
value类似于数据库中键值对的值,如果将输入序列的每个token视为数据库中的一条记录,则value可以视为各条数据库记录中有价值的信息。模型评估当前query与哪些key相关,然后检索这些key对应的value,构成包含上下文信息的context向量。
可以使用如下代码初始化参数矩阵 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv,并计算 x 2 x_2 x2对应的query向量 q 2 q_2 q2,以及所有输入向量 x i x_i xi对应的key及value向量 k i k_i ki和 v i v_i vi:
import torch
torch.manual_seed(123)
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2
W_q = torch.nn.Parameter(torch.rand(d_in, d_out))
W_k = torch.nn.Parameter(torch.rand(d_in, d_out))
W_v = torch.nn.Parameter(torch.rand(d_in, d_out))
query_2 = x_2 @ W_q
keys = inputs @ W_k
values = inputs @ W_v
print("query_2:\n", query_2)
print("keys:\n", keys)
print("values:\n", values)
执行上面代码,打印结果如下:
query_2:
tensor([0.4306, 1.4551], grad_fn=<SqueezeBackward4>)
keys:
tensor([[0.3669, 0.7646],
[0.4433, 1.1419],
[0.4361, 1.1156],
[0.2408, 0.6706],
[0.1827, 0.3292],
[0.3275, 0.9642]], grad_fn=<MmBackward0>)
values:
tensor([[0.1855, 0.8812],
[0.3951, 1.0037],
[0.3879, 0.9831],
[0.2393, 0.5493],
[0.1492, 0.3346],
[0.3221, 0.7863]], grad_fn=<MmBackward0>)
在深度学习实践中,并不会使用torch.nn.Parameter(torch.rand(...))
的方式初始化参数矩阵
W
q
,
W
k
,
W
v
W_q, W_k, W_v
Wq,Wk,Wv,一般会直接使用PyTorch内置的torch.nn.Linear
层。torch.nn.Linear
层的数学模型为
y
=
W
x
+
b
y=Wx+b
y=Wx+b,可以令bais=False
,使数学模型简化为
y
=
W
x
y=Wx
y=Wx。torch.nn.Linear
层的参数矩阵初始化方法更好,可以使模型训练过程更稳定,模型更容易收敛。
可以使用如下代码初始化3个torch.nn.Linear
层W_query, W_key, W_value
,并计算query向量
q
2
q_2
q2,以及所有key向量和value向量:
W_query = torch.nn.Linear(d_in, d_out, bias=False)
W_key = torch.nn.Linear(d_in, d_out, bias=False)
W_value = torch.nn.Linear(d_in, d_out, bias=False)
query_2 = W_query(x_2)
keys = W_key(inputs)
values = W_value(inputs)
print("query_2:\n", query_2)
print("keys:\n", keys)
print("values:\n", values)
执行上面代码,打印结果如下:
query_2:
tensor([0.3558, 0.5643], grad_fn=<SqueezeBackward3>)
keys:
tensor([[-0.3132, -0.2272],
[-0.1536, 0.2768],
[-0.1574, 0.2865],
[-0.0360, 0.1826],
[-0.1805, 0.3798],
[-0.0080, 0.0967]], grad_fn=<MmBackward0>)
values:
tensor([[0.4772, 0.1063],
[0.6770, 0.4980],
[0.6763, 0.4946],
[0.3514, 0.3055],
[0.4736, 0.2954],
[0.3836, 0.3539]], grad_fn=<MmBackward0>)
torch.nn.Linear
层初始化参数矩阵的方法与torch.nn.Parameter(torch.rand(...))
不相同,因此输出了不同的qkv向量。
2.2 计算注意力分数
如下图所示,将query向量 q 2 q_2 q2分别点乘所有输入向量 x i x_i xi对应的key向量 k i k_i ki,得到注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,⋯,ω26。
可以使用如下代码将query向量 q 2 q_2 q2点乘所有key向量 k i k_i ki组成的矩阵keys,一次性批量计算出所有注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,⋯,ω26:
attn_scores_2 = query_2 @ keys.T # All attention scores for given query q_2
print(attn_scores_2)
执行上面代码,打印结果如下:
tensor([-0.2396, 0.1015, 0.1057, 0.0902, 0.1501, 0.0518], grad_fn=<SqueezeBackward3>)
2.3 计算注意力权重
缩放点积注意力机制将注意力分数归一化得到注意力权重的方法与前文所述简单自注意力机制并不完全相同。其首先将每个注意力分数
ω
2
i
\omega_{2i}
ω2i除以key向量维度的平方根,得到经过缩放的注意力分数(scaled attention score)
ω
2
i
s
c
a
l
e
d
\omega_{2i}^{scaled}
ω2iscaled,然后再使用softmax
函数将经过缩放的注意力分数
ω
21
s
c
a
l
e
d
,
ω
22
s
c
a
l
e
d
,
⋯
,
ω
26
s
c
a
l
e
d
\omega_{21}^{scaled}, \omega_{22}^{scaled}, \cdots, \omega_{26}^{scaled}
ω21scaled,ω22scaled,⋯,ω26scaled归一化,得到注意力权重
α
21
,
α
22
,
⋯
,
α
26
\alpha_{21}, \alpha_{22}, \cdots, \alpha_{26}
α21,α22,⋯,α26。
key向量维度与query向量维度一定相同(只有两个维度相同的向量才可以做点积运算)。value向量维度等于生成context向量的维度,与key向量及query向量维度可以不相同。
假设输入向量为 x = [ x 1 , x 2 , ⋯ , x n ] x=[x_1, x_2, \cdots, x_n] x=[x1,x2,⋯,xn],将 x x x输入
softmax
函数,得到输出 y = [ y 1 , y 2 , ⋯ , y n ] y=[y_1, y_2, \cdots, y_n] y=[y1,y2,⋯,yn],则softmax
函数的导数 ∂ y i ∂ x j = y i ( δ i j − y j ) \frac{\partial y_i}{\partial x_j}=y_i(\delta_{ij}-y_j) ∂xj∂yi=yi(δij−yj)。其中 δ i j \delta_{ij} δij当 i = j i=j i=j时取值为1,否则为0。如下面的代码所示,当输入向量 x x x各个分量的数值比较小时,
softmax
函数输出的向量 y y y的各个分量数值分布比较均匀。输入向量 x x x各个分量的数值越大,则输出的向量 y y y中与输入向量 x x x最大数值对应的分量会越接近1,其余分量会越接近0:x = torch.tensor([0.1, 0.3, 0.5, 0.6, 0.9]) print(torch.softmax(x, dim=0)) print(torch.softmax(x * 10, dim=0)) print(torch.softmax(x * 100, dim=0))
执行上面代码,打印结果如下:
tensor([0.1318, 0.1610, 0.1966, 0.2173, 0.2933]) tensor([3.1325e-04, 2.3146e-03, 1.7103e-02, 4.6490e-02, 9.3378e-01]) tensor([1.8049e-35, 8.7565e-27, 4.2484e-18, 9.3577e-14, 1.0000e+00])
大语言模型中qkv向量的维度一般都特别大。注意力分数等于query及key向量的内积,向量维度越大,则注意力分数的数值越大。
softmax
函数的导数 ∂ y i ∂ x j = y i ( δ i j − y j ) \frac{\partial y_i}{\partial x_j}=y_i(\delta_{ij}-y_j) ∂xj∂yi=yi(δij−yj),如果直接将非常大的注意力分数输入softmax
函数,会使训练大语言模型时反向传播计算出大部分参数的梯度都接近0,严重降低模型训练效率,甚至会导致模型训练停滞,损失函数无法收敛。key向量的维度越大,计算出来的注意力分数会越大,key向量维度的平方根也会越大。将注意力分数除以key向量维度的平方根,可以使输入
softmax
函数的向量各个分量的数值相对小,训练大语言模型时反向传播计算出的参数梯度大小比较合适,模型比较容易收敛。这也是这种自注意力机制被称为缩放点积注意力机制的原因。
可以使用如下代码计算经过缩放的注意力分数,并将经过缩放的注意力分数归一化,得到注意力权重:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)
执行上面代码,打印结果如下:
tensor([0.1359, 0.1730, 0.1735, 0.1716, 0.1790, 0.1670], grad_fn=<SoftmaxBackward0>)
2.4 计算context向量
缩放点积注意力机制使用注意力权重对value向量加权求和计算context向量,context向量 z 2 = ∑ i α 2 i v i z_2=\sum_i\alpha_{2i}v_i z2=∑iα2ivi。
可以使用如下代码计算context向量 z 2 z_2 z2:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
执行上面代码,打印结果如下:
tensor([0.5084, 0.3508], grad_fn=<SqueezeBackward3>)
3. 构建神经网络模块ScaledDotProductAttention
使用PyTorch构建神经网络模型或神经网络模型中的某个子模块需要实现一个torch.nn.Module
的子类,并重写__init__
构造方法及forward
方法。__init__
方法用于定义模型的结构,创建并初始化模型中的各个组件。forward
方法用于定义模型的前向计算流程,输入数据依次经过__init__
方法中定义的各个组件,得到模型的输出。
可以使用如下代码构建神经网络模块ScaledDotProductAttention
:
class ScaledDotProductAttention(torch.nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
queries = self.W_query(x)
keys = self.W_key(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
__init__
方法初始化了3个torch.nn.Linear
层。forward
方法将x
分别输入3个torch.nn.Linear
层,得到queries,keys及values。将queries与keys两个张量相乘,计算出所有注意力分数attn_scores,并根据2.3所述方法计算出注意力权重attn_weights。最后将attn_weights与values两个张量相乘,一次性批量计算出所有输入向量
x
i
x_i
xi对应的context向量context_vec。
可以使用如下代码实例化ScaledDotProductAttention
类对象,输入inputs
,计算context向量:
sdpa = ScaledDotProductAttention(d_in, d_out)
print(sdpa(inputs))
执行上面代码,打印结果如下:
tensor([[0.5322, 0.2491],
[0.5316, 0.2488],
[0.5316, 0.2488],
[0.5340, 0.2501],
[0.5331, 0.2497],
[0.5337, 0.2499]], grad_fn=<MmBackward0>)
4. 结束语
缩放点积注意力机制使用三个参数矩阵 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv将输入向量 x i x_i xi映射成 q i , k i , v i q_i, k_i, v_i qi,ki,vi,计算query与key向量的点积作为注意力分数,使用注意力权重对value向量加权求和计算context向量。
缩放点积注意力机制的精髓在于将注意力分数除以key向量维度的平方根,使输入softmax
函数的数值比较小,计算出的注意力权重分布比较合理,避免训练模型时反向传播计算部分参数的梯度接近零,从而提升模型训练效率。在深度学习领域,神经网络架构中类似这种数值处理(如Batch Normalization等等)绝大部分原因都是使反向传播时计算的参数梯度大小相对更合理,避免梯度消失及梯度爆炸问题,遇到这种数值处理设计从梯度方面思考原因一般不会错。这种数值处理看似不难,但是只有真正深度理解神经网络内部计算逻辑细节,才能做出这种设计,而不是有手就行,不真正理解根本想不到。
假如要我招聘面试技术细节,我就会问设计这种数值处理的原因。