Transformer在计算softmax之前为什么要除以维度的开方

在计算注意力时,特别是在使用缩放点积注意力(Scaled Dot-Product Attention)时,确实会用到除以维度的平方根。本文详细这一步操作的意义和原因。

假设我们有查询向量 Q Q Q、键向量 K K K 和值向量 V V V,它们的维度为 d k d_k dk。注意力计算的公式为:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V
意义:

  • 防止过大的点积导致梯度消失: 直接计算 Q K T Q K^T QKT 可能导致结果的范围过大,特别是当 d k d_k dk 较大时,点积的值可能会迅速增大,从而使得 softmax 的输出趋向于极值,导致梯度消失。这种情况下,模型的学习会变得不稳定。
    • Softmax的输出定义为: softmax ( z i ) = e z i ∑ j e z j \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}} softmax(zi)=jezjezi当某个 z i z_i zi 的值很大(例如,远大于其他 z j z_j zj 的值),则 softmax ( z i ) \text{softmax}(z_i) softmax(zi) 将接近于1,而其他 softmax ( z j ) \text{softmax}(z_j) softmax(zj) j ≠ i j \neq i j=i)将接近于0。Softmax的梯度计算涉及到输出的概率分布。对于某个特定的输出 i i i,其梯度公式为: ∂ softmax ( z i ) ∂ z k = softmax ( z i ) ( δ i k − softmax ( z k ) ) \frac{\partial \text{softmax}(z_i)}{\partial z_k} = \text{softmax}(z_i) \left( \delta_{ik} - \text{softmax}(z_k) \right) zksoftmax(zi)=softmax(zi)(δiksoftmax(zk))其中 δ i k \delta_{ik} δik 是Kronecker delta。
    • z i z_i zi 很大时, softmax ( z i ) \text{softmax}(z_i) softmax(zi) 接近1,而 softmax ( z k ) \text{softmax}(z_k) softmax(zk) k ≠ i k \neq i k=i)接近0,因此梯度大约为: ∂ softmax ( z i ) ∂ z k ≈ 1 ⋅ ( 1 − 1 ) = 0 (when  k = i ) \frac{\partial \text{softmax}(z_i)}{\partial z_k} \approx 1 \cdot (1 - 1) = 0 \quad \text{(when \( k = i \))} zksoftmax(zi)1(11)=0(when k=i) ∂ softmax ( z k ) ∂ z k ≈ 1 ⋅ ( 0 − 0 ) = 0 (when  k ≠ i ) \frac{\partial \text{softmax}(z_k)}{\partial z_k} \approx 1 \cdot (0 - 0) = 0 \quad \text{(when \( k \neq i \))} zksoftmax(zk)1(00)=0(when k=i)
    • z i z_i zi 很小时, softmax ( z i ) \text{softmax}(z_i) softmax(zi) 接近0,因此梯度大约为: ∂ softmax ( z i ) ∂ z k ≈ 0 ⋅ ( 1 − 0 / 1 ) = 0 (when  k = i ) \frac{\partial \text{softmax}(z_i)}{\partial z_k} \approx 0 \cdot (1 - 0/1) = 0 \quad \text{(when \( k = i \))} zksoftmax(zi)0(10/1)=0(when k=i) ∂ softmax ( z k ) ∂ z k ≈ 0 ⋅ ( 0 − 0 / 1 ) = 0 (when  k ≠ i ) \frac{\partial \text{softmax}(z_k)}{\partial z_k} \approx 0 \cdot (0 - 0/1) = 0 \quad \text{(when \( k \neq i \))} zksoftmax(zk)0(00/1)=0(when k=i)由此可见,当softmax输出集中在某一项时,导数的计算会导致梯度几乎为0,这被称为梯度消失。这样的情况会使得反向传播过程中对相应参数的更新变得非常小,从而导致学习过程缓慢或停滞。
      因此,当 softmax 输出极端化时,导致的梯度变小,会使得模型在训练过程中难以有效地学习和更新权重,影响整体性能。
  1. 平衡各个元素的影响: 除以 d k \sqrt{d_k} dk 有助于标准化点积的结果,使得不同维度的输入对最终输出的影响保持一致,从而提高模型的训练效果。
  2. 确保softmax的有效性: 如果 d k d_k dk 较大,点积 Q K T Q K^T QKT 的结果可能会非常大。假设 Q K T Q K^T QKT 的值是一个较大的数 z z z,则 softmax 的计算为: softmax ( z i ) = e z i ∑ j e z j \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}} softmax(zi)=jezjezi z i z_i zi 的值很大时,指数函数 e z i e^{z_i} ezi 将迅速增大,而其他较小的 z j z_j zj 的影响几乎可以忽略不计。这将导致 softmax 输出接近于0或1,从而使得注意力权重极端化,失去分布的有效性。通过缩放,可以避免softmax函数的输入值过大或过小,确保计算出的注意力权重在合理范围内,从而使得模型能够有效地分配注意力。
    因此,使用 d k \sqrt{d_k} dk 作为缩放因子,有助于提高计算的稳定性和效率。
     

为什么是除以 d k \sqrt{d_k} dk 而不是其他呢 ?

 

点积的方差: 在神经网络训练过程中,权重初始化和输入数据通常是随机的。在这个背景下,经过训练后的 Q Q Q K K K 向量在某种程度上可以被视为随机向量,因为它们从随机初始化中演变而来,并且通过训练过程相互独立地学习。在计算 Q K T Q K^T QKT 时,点积的结果随着向量维度的增加而增大。具体来说,两个独立的随机向量的点积的期望值是 d k d_k dk(当其元素均为零均值时)。因此,点积的方差大约与 d k d_k dk 成正比。通过除以 d k \sqrt{d_k} dk ,可以将结果的标准差(方差的平方根)缩放回合理范围,从而控制输出的稳定性。

以下是除以 d k \sqrt{d_k} dk 的详细证明

假设 Q Q Q中的一项为 q q q K K K中的一项为 k k k Q Q Q K K K分别有 d k d_k dk项。根据上面的假设, q q q k k k两个变量的均值的方差均为0和1,且相互独立:

  • E [ q ] = 0 \mathbb{E}[q] = 0 E[q]=0, Var ( q ) = 1 \text{Var}(q) = 1 Var(q)=1
  • E [ k ] = 0 \mathbb{E}[k] = 0 E[k]=0, Var ( k ) = 1 \text{Var}(k) = 1 Var(k)=1

计算 Var ( q k ) \text{Var}(qk) Var(qk)

根据方差的定义:
Var ( q k ) = E [ ( q k ) 2 ] − ( E [ q k ] ) 2 \text{Var}(qk) = \mathbb{E}[(qk)^2] - (\mathbb{E}[qk])^2 Var(qk)=E[(qk)2](E[qk])2

1. 计算 E [ q k ] \mathbb{E}[qk] E[qk]

如果 q q q k k k 独立:
E [ q k ] = E [ q ] ⋅ E [ k ] = 0 ⋅ 0 = 0 \mathbb{E}[qk] = \mathbb{E}[q] \cdot \mathbb{E}[k] = 0 \cdot 0 = 0 E[qk]=E[q]E[k]=00=0
所以:
( E [ q k ] ) 2 = 0 2 = 0 (\mathbb{E}[qk])^2 = 0^2 = 0 (E[qk])2=02=0

2. 计算 E [ ( q k ) 2 ] \mathbb{E}[(qk)^2] E[(qk)2]

( q k ) 2 = q 2 k 2 (qk)^2 = q^2 k^2 (qk)2=q2k2
如果 q q q k k k 独立,则:
E [ ( q k ) 2 ] = E [ q 2 ] ⋅ E [ k 2 ] \mathbb{E}[(qk)^2] = \mathbb{E}[q^2] \cdot \mathbb{E}[k^2] E[(qk)2]=E[q2]E[k2]
对于方差,我们知道:
E [ q 2 ] = Var ( q ) + ( E [ q ] ) 2 = 1 + 0 2 = 1 \mathbb{E}[q^2] = \text{Var}(q) + (\mathbb{E}[q])^2 = 1 + 0^2 = 1 E[q2]=Var(q)+(E[q])2=1+02=1
E [ k 2 ] = Var ( k ) + ( E [ k ] ) 2 = 1 + 0 2 = 1 \mathbb{E}[k^2] = \text{Var}(k) + (\mathbb{E}[k])^2 = 1 + 0^2 = 1 E[k2]=Var(k)+(E[k])2=1+02=1
因此:
E [ ( q k ) 2 ] = E [ q 2 ] ⋅ E [ k 2 ] = 1 ⋅ 1 = 1 \mathbb{E}[(qk)^2] = \mathbb{E}[q^2] \cdot \mathbb{E}[k^2] = 1 \cdot 1 = 1 E[(qk)2]=E[q2]E[k2]=11=1

3. 计算方差

将结果代入方差公式:
Var ( q k ) = E [ ( q k ) 2 ] − ( E [ q k ] ) 2 = 1 − 0 = 1 \text{Var}(qk) = \mathbb{E}[(qk)^2] - (\mathbb{E}[qk])^2 = 1 - 0 = 1 Var(qk)=E[(qk)2](E[qk])2=10=1

结果

因此,两个变量 q q q k k k 相乘的方差为:
Var ( q k ) = 1 \text{Var}(qk) = 1 Var(qk)=1

进一步计算方差 Var ( Q K T ) \text{Var}(QK^T) Var(QKT)

给定:
Q = [ q 1 , q 2 , … , q d k ] (维度为  ( 1 , d k ) ) Q = [q_1, q_2, \ldots, q_{d_k}] \quad \text{(维度为 } (1, d_k)\text{)} Q=[q1,q2,,qdk](维度为 (1,dk))
K = [ k 1 , k 2 , … , k d k ] (维度为  ( 1 , d k ) ) K = [k_1, k_2, \ldots, k_{d_k}] \quad \text{(维度为 } (1, d_k)\text{)} K=[k1,k2,,kdk](维度为 (1,dk))
计算外积:
Q K T = ∑ i = 1 d k q i k i Q K^T = \sum_{i=1}^{d_k} q_i k_i QKT=i=1dkqiki

  1. 计算 E [ Q K T ] \mathbb{E}[QK^T] E[QKT]
    假设 q i q_i qi k i k_i ki 是独立的:
    E [ Q K T ] = E [ ∑ i = 1 d k q i k i ] = ∑ i = 1 d k E [ q i k i ] = ∑ i = 1 d k E [ q i ] ⋅ E [ k i ] = ∑ i = 1 d k 0 ⋅ 0 = 0 \mathbb{E}[QK^T] = \mathbb{E}\left[\sum_{i=1}^{d_k} q_i k_i\right] = \sum_{i=1}^{d_k} \mathbb{E}[q_i k_i] = \sum_{i=1}^{d_k} \mathbb{E}[q_i] \cdot \mathbb{E}[k_i] = \sum_{i=1}^{d_k} 0 \cdot 0 = 0 E[QKT]=E[i=1dkqiki]=i=1dkE[qiki]=i=1dkE[qi]E[ki]=i=1dk00=0
  2. 计算 E [ ( Q K T ) 2 ] \mathbb{E}[(QK^T)^2] E[(QKT)2]
    展开:
    E [ ( Q K T ) 2 ] = E [ ( ∑ i = 1 d k q i k i ) 2 ] \mathbb{E}[(QK^T)^2] = \mathbb{E}\left[\left(\sum_{i=1}^{d_k} q_i k_i\right)^2\right] E[(QKT)2]=E (i=1dkqiki)2
    使用协方差:
    E [ ( Q K T ) 2 ] = ∑ i = 1 d k E [ q i 2 ] E [ k i 2 ] + ∑ i ≠ j E [ q i k i ] E [ q j k j ] \mathbb{E}[(QK^T)^2] = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] + \sum_{i \neq j} \mathbb{E}[q_i k_i] \mathbb{E}[q_j k_j] E[(QKT)2]=i=1dkE[qi2]E[ki2]+i=jE[qiki]E[qjkj]
    假设 q i q_i qi k j k_j kj 是独立的:
    E [ q i k j ] = 0  (对于  i ≠ j ) \mathbb{E}[q_i k_j] = 0 \text{ (对于 } i \neq j\text{)} E[qikj]=0 (对于 i=j)
    所以:
    E [ ( Q K T ) 2 ] = ∑ i = 1 d k E [ q i 2 ] E [ k i 2 ] \mathbb{E}[(QK^T)^2] = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] E[(QKT)2]=i=1dkE[qi2]E[ki2]
    已知:
    E [ q i 2 ] = 1 , E [ k i 2 ] = 1 \mathbb{E}[q_i^2] = 1, \quad \mathbb{E}[k_i^2] = 1 E[qi2]=1,E[ki2]=1
    因此:
    E [ ( Q K T ) 2 ] = ∑ i = 1 d k 1 ⋅ 1 = d k \mathbb{E}[(QK^T)^2] = \sum_{i=1}^{d_k} 1 \cdot 1 = d_k E[(QKT)2]=i=1dk11=dk
  3. 计算方差
    方差公式为:
    Var ( Q K T ) = E [ ( Q K T ) 2 ] − ( E [ Q K T ] ) 2 \text{Var}(QK^T) = \mathbb{E}[(QK^T)^2] - (\mathbb{E}[QK^T])^2 Var(QKT)=E[(QKT)2](E[QKT])2
    Var ( Q K T ) = d k − 0 2 = d k \text{Var}(QK^T) = d_k - 0^2 = d_k Var(QKT)=dk02=dk

结果

因此,最终得到:
Var ( Q K T ) = d k \text{Var}(QK^T) = d_k Var(QKT)=dk

综上,在计算 Q K T Q K^T QKT 时,点积的结果的期望值是 d k d_k dk。通过除以 d k \sqrt{d_k} dk ,可以将结果的标准差(方差的平方根)缩放为1,从而控制输出的稳定性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值