在计算注意力时,特别是在使用缩放点积注意力(Scaled Dot-Product Attention)时,确实会用到除以维度的平方根。本文详细这一步操作的意义和原因。
假设我们有查询向量
Q
Q
Q、键向量
K
K
K 和值向量
V
V
V,它们的维度为
d
k
d_k
dk。注意力计算的公式为:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
Attention(Q,K,V)=softmax(dkQKT)V
意义:
- 防止过大的点积导致梯度消失: 直接计算
Q
K
T
Q K^T
QKT 可能导致结果的范围过大,特别是当
d
k
d_k
dk 较大时,点积的值可能会迅速增大,从而使得 softmax 的输出趋向于极值,导致梯度消失。这种情况下,模型的学习会变得不稳定。
- Softmax的输出定义为: softmax ( z i ) = e z i ∑ j e z j \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}} softmax(zi)=∑jezjezi当某个 z i z_i zi 的值很大(例如,远大于其他 z j z_j zj 的值),则 softmax ( z i ) \text{softmax}(z_i) softmax(zi) 将接近于1,而其他 softmax ( z j ) \text{softmax}(z_j) softmax(zj) ( j ≠ i j \neq i j=i)将接近于0。Softmax的梯度计算涉及到输出的概率分布。对于某个特定的输出 i i i,其梯度公式为: ∂ softmax ( z i ) ∂ z k = softmax ( z i ) ( δ i k − softmax ( z k ) ) \frac{\partial \text{softmax}(z_i)}{\partial z_k} = \text{softmax}(z_i) \left( \delta_{ik} - \text{softmax}(z_k) \right) ∂zk∂softmax(zi)=softmax(zi)(δik−softmax(zk))其中 δ i k \delta_{ik} δik 是Kronecker delta。
- 当 z i z_i zi 很大时, softmax ( z i ) \text{softmax}(z_i) softmax(zi) 接近1,而 softmax ( z k ) \text{softmax}(z_k) softmax(zk)( k ≠ i k \neq i k=i)接近0,因此梯度大约为: ∂ softmax ( z i ) ∂ z k ≈ 1 ⋅ ( 1 − 1 ) = 0 (when k = i ) \frac{\partial \text{softmax}(z_i)}{\partial z_k} \approx 1 \cdot (1 - 1) = 0 \quad \text{(when \( k = i \))} ∂zk∂softmax(zi)≈1⋅(1−1)=0(when k=i) ∂ softmax ( z k ) ∂ z k ≈ 1 ⋅ ( 0 − 0 ) = 0 (when k ≠ i ) \frac{\partial \text{softmax}(z_k)}{\partial z_k} \approx 1 \cdot (0 - 0) = 0 \quad \text{(when \( k \neq i \))} ∂zk∂softmax(zk)≈1⋅(0−0)=0(when k=i)
- 当
z
i
z_i
zi 很小时,
softmax
(
z
i
)
\text{softmax}(z_i)
softmax(zi) 接近0,因此梯度大约为:
∂
softmax
(
z
i
)
∂
z
k
≈
0
⋅
(
1
−
0
/
1
)
=
0
(when
k
=
i
)
\frac{\partial \text{softmax}(z_i)}{\partial z_k} \approx 0 \cdot (1 - 0/1) = 0 \quad \text{(when \( k = i \))}
∂zk∂softmax(zi)≈0⋅(1−0/1)=0(when k=i)
∂
softmax
(
z
k
)
∂
z
k
≈
0
⋅
(
0
−
0
/
1
)
=
0
(when
k
≠
i
)
\frac{\partial \text{softmax}(z_k)}{\partial z_k} \approx 0 \cdot (0 - 0/1) = 0 \quad \text{(when \( k \neq i \))}
∂zk∂softmax(zk)≈0⋅(0−0/1)=0(when k=i)由此可见,当softmax输出集中在某一项时,导数的计算会导致梯度几乎为0,这被称为梯度消失。这样的情况会使得反向传播过程中对相应参数的更新变得非常小,从而导致学习过程缓慢或停滞。
因此,当 softmax 输出极端化时,导致的梯度变小,会使得模型在训练过程中难以有效地学习和更新权重,影响整体性能。
- 平衡各个元素的影响: 除以 d k \sqrt{d_k} dk 有助于标准化点积的结果,使得不同维度的输入对最终输出的影响保持一致,从而提高模型的训练效果。
- 确保softmax的有效性: 如果
d
k
d_k
dk 较大,点积
Q
K
T
Q K^T
QKT 的结果可能会非常大。假设
Q
K
T
Q K^T
QKT 的值是一个较大的数
z
z
z,则 softmax 的计算为:
softmax
(
z
i
)
=
e
z
i
∑
j
e
z
j
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}
softmax(zi)=∑jezjezi当
z
i
z_i
zi 的值很大时,指数函数
e
z
i
e^{z_i}
ezi 将迅速增大,而其他较小的
z
j
z_j
zj 的影响几乎可以忽略不计。这将导致 softmax 输出接近于0或1,从而使得注意力权重极端化,失去分布的有效性。通过缩放,可以避免softmax函数的输入值过大或过小,确保计算出的注意力权重在合理范围内,从而使得模型能够有效地分配注意力。
因此,使用 d k \sqrt{d_k} dk 作为缩放因子,有助于提高计算的稳定性和效率。
为什么是除以 d k \sqrt{d_k} dk 而不是其他呢 ?
点积的方差: 在神经网络训练过程中,权重初始化和输入数据通常是随机的。在这个背景下,经过训练后的 Q Q Q 和 K K K 向量在某种程度上可以被视为随机向量,因为它们从随机初始化中演变而来,并且通过训练过程相互独立地学习。在计算 Q K T Q K^T QKT 时,点积的结果随着向量维度的增加而增大。具体来说,两个独立的随机向量的点积的期望值是 d k d_k dk(当其元素均为零均值时)。因此,点积的方差大约与 d k d_k dk 成正比。通过除以 d k \sqrt{d_k} dk,可以将结果的标准差(方差的平方根)缩放回合理范围,从而控制输出的稳定性。
以下是除以 d k \sqrt{d_k} dk的详细证明
假设 Q Q Q中的一项为 q q q, K K K中的一项为 k k k, Q Q Q和 K K K分别有 d k d_k dk项。根据上面的假设, q q q和 k k k两个变量的均值的方差均为0和1,且相互独立:
- E [ q ] = 0 \mathbb{E}[q] = 0 E[q]=0, Var ( q ) = 1 \text{Var}(q) = 1 Var(q)=1
- E [ k ] = 0 \mathbb{E}[k] = 0 E[k]=0, Var ( k ) = 1 \text{Var}(k) = 1 Var(k)=1
计算 Var ( q k ) \text{Var}(qk) Var(qk)
根据方差的定义:
Var
(
q
k
)
=
E
[
(
q
k
)
2
]
−
(
E
[
q
k
]
)
2
\text{Var}(qk) = \mathbb{E}[(qk)^2] - (\mathbb{E}[qk])^2
Var(qk)=E[(qk)2]−(E[qk])2
1. 计算 E [ q k ] \mathbb{E}[qk] E[qk]
如果
q
q
q 和
k
k
k 独立:
E
[
q
k
]
=
E
[
q
]
⋅
E
[
k
]
=
0
⋅
0
=
0
\mathbb{E}[qk] = \mathbb{E}[q] \cdot \mathbb{E}[k] = 0 \cdot 0 = 0
E[qk]=E[q]⋅E[k]=0⋅0=0
所以:
(
E
[
q
k
]
)
2
=
0
2
=
0
(\mathbb{E}[qk])^2 = 0^2 = 0
(E[qk])2=02=0
2. 计算 E [ ( q k ) 2 ] \mathbb{E}[(qk)^2] E[(qk)2]
(
q
k
)
2
=
q
2
k
2
(qk)^2 = q^2 k^2
(qk)2=q2k2
如果
q
q
q 和
k
k
k 独立,则:
E
[
(
q
k
)
2
]
=
E
[
q
2
]
⋅
E
[
k
2
]
\mathbb{E}[(qk)^2] = \mathbb{E}[q^2] \cdot \mathbb{E}[k^2]
E[(qk)2]=E[q2]⋅E[k2]
对于方差,我们知道:
E
[
q
2
]
=
Var
(
q
)
+
(
E
[
q
]
)
2
=
1
+
0
2
=
1
\mathbb{E}[q^2] = \text{Var}(q) + (\mathbb{E}[q])^2 = 1 + 0^2 = 1
E[q2]=Var(q)+(E[q])2=1+02=1
E
[
k
2
]
=
Var
(
k
)
+
(
E
[
k
]
)
2
=
1
+
0
2
=
1
\mathbb{E}[k^2] = \text{Var}(k) + (\mathbb{E}[k])^2 = 1 + 0^2 = 1
E[k2]=Var(k)+(E[k])2=1+02=1
因此:
E
[
(
q
k
)
2
]
=
E
[
q
2
]
⋅
E
[
k
2
]
=
1
⋅
1
=
1
\mathbb{E}[(qk)^2] = \mathbb{E}[q^2] \cdot \mathbb{E}[k^2] = 1 \cdot 1 = 1
E[(qk)2]=E[q2]⋅E[k2]=1⋅1=1
3. 计算方差
将结果代入方差公式:
Var
(
q
k
)
=
E
[
(
q
k
)
2
]
−
(
E
[
q
k
]
)
2
=
1
−
0
=
1
\text{Var}(qk) = \mathbb{E}[(qk)^2] - (\mathbb{E}[qk])^2 = 1 - 0 = 1
Var(qk)=E[(qk)2]−(E[qk])2=1−0=1
结果
因此,两个变量
q
q
q 和
k
k
k 相乘的方差为:
Var
(
q
k
)
=
1
\text{Var}(qk) = 1
Var(qk)=1
进一步计算方差 Var ( Q K T ) \text{Var}(QK^T) Var(QKT)
给定:
Q
=
[
q
1
,
q
2
,
…
,
q
d
k
]
(维度为
(
1
,
d
k
)
)
Q = [q_1, q_2, \ldots, q_{d_k}] \quad \text{(维度为 } (1, d_k)\text{)}
Q=[q1,q2,…,qdk](维度为 (1,dk))
K
=
[
k
1
,
k
2
,
…
,
k
d
k
]
(维度为
(
1
,
d
k
)
)
K = [k_1, k_2, \ldots, k_{d_k}] \quad \text{(维度为 } (1, d_k)\text{)}
K=[k1,k2,…,kdk](维度为 (1,dk))
计算外积:
Q
K
T
=
∑
i
=
1
d
k
q
i
k
i
Q K^T = \sum_{i=1}^{d_k} q_i k_i
QKT=i=1∑dkqiki
- 计算
E
[
Q
K
T
]
\mathbb{E}[QK^T]
E[QKT]
假设 q i q_i qi 和 k i k_i ki 是独立的:
E [ Q K T ] = E [ ∑ i = 1 d k q i k i ] = ∑ i = 1 d k E [ q i k i ] = ∑ i = 1 d k E [ q i ] ⋅ E [ k i ] = ∑ i = 1 d k 0 ⋅ 0 = 0 \mathbb{E}[QK^T] = \mathbb{E}\left[\sum_{i=1}^{d_k} q_i k_i\right] = \sum_{i=1}^{d_k} \mathbb{E}[q_i k_i] = \sum_{i=1}^{d_k} \mathbb{E}[q_i] \cdot \mathbb{E}[k_i] = \sum_{i=1}^{d_k} 0 \cdot 0 = 0 E[QKT]=E[i=1∑dkqiki]=i=1∑dkE[qiki]=i=1∑dkE[qi]⋅E[ki]=i=1∑dk0⋅0=0 - 计算
E
[
(
Q
K
T
)
2
]
\mathbb{E}[(QK^T)^2]
E[(QKT)2]
展开:
E [ ( Q K T ) 2 ] = E [ ( ∑ i = 1 d k q i k i ) 2 ] \mathbb{E}[(QK^T)^2] = \mathbb{E}\left[\left(\sum_{i=1}^{d_k} q_i k_i\right)^2\right] E[(QKT)2]=E (i=1∑dkqiki)2
使用协方差:
E [ ( Q K T ) 2 ] = ∑ i = 1 d k E [ q i 2 ] E [ k i 2 ] + ∑ i ≠ j E [ q i k i ] E [ q j k j ] \mathbb{E}[(QK^T)^2] = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] + \sum_{i \neq j} \mathbb{E}[q_i k_i] \mathbb{E}[q_j k_j] E[(QKT)2]=i=1∑dkE[qi2]E[ki2]+i=j∑E[qiki]E[qjkj]
假设 q i q_i qi 和 k j k_j kj 是独立的:
E [ q i k j ] = 0 (对于 i ≠ j ) \mathbb{E}[q_i k_j] = 0 \text{ (对于 } i \neq j\text{)} E[qikj]=0 (对于 i=j)
所以:
E [ ( Q K T ) 2 ] = ∑ i = 1 d k E [ q i 2 ] E [ k i 2 ] \mathbb{E}[(QK^T)^2] = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] E[(QKT)2]=i=1∑dkE[qi2]E[ki2]
已知:
E [ q i 2 ] = 1 , E [ k i 2 ] = 1 \mathbb{E}[q_i^2] = 1, \quad \mathbb{E}[k_i^2] = 1 E[qi2]=1,E[ki2]=1
因此:
E [ ( Q K T ) 2 ] = ∑ i = 1 d k 1 ⋅ 1 = d k \mathbb{E}[(QK^T)^2] = \sum_{i=1}^{d_k} 1 \cdot 1 = d_k E[(QKT)2]=i=1∑dk1⋅1=dk - 计算方差
方差公式为:
Var ( Q K T ) = E [ ( Q K T ) 2 ] − ( E [ Q K T ] ) 2 \text{Var}(QK^T) = \mathbb{E}[(QK^T)^2] - (\mathbb{E}[QK^T])^2 Var(QKT)=E[(QKT)2]−(E[QKT])2
Var ( Q K T ) = d k − 0 2 = d k \text{Var}(QK^T) = d_k - 0^2 = d_k Var(QKT)=dk−02=dk
结果
因此,最终得到:
Var
(
Q
K
T
)
=
d
k
\text{Var}(QK^T) = d_k
Var(QKT)=dk
综上,在计算 Q K T Q K^T QKT 时,点积的结果的期望值是 d k d_k dk。通过除以 d k \sqrt{d_k} dk,可以将结果的标准差(方差的平方根)缩放为1,从而控制输出的稳定性。