transfomer中attention为什么要除以根号d_k

简介

得到矩阵 Q, K, V之后就可以计算出 Self-Attention 的输出了,计算的公式如下:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d k ) V Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=Softmax(dk QKT)V

好处

除以维度的开方,可以将数据向0方向集中,使得经过softmax后的梯度更大.
从数学上分析,可以使得QK的分布和Q/K保持一致,

推导

对于两个独立的正态分布而言,两者的加法的期望和方差就是两个独立分布的期望和方差。
qk_T的计算过程为[len_q,dim][dim,len_k]=[len_q,len_k],qk的元素等于dim个乘积的和。对于0-1分布表乘积不会影响期望和方差,但是求和操作会使得方差乘以dim,因此对qk元素除以sqrt(dim)把标准差压回1.

这里展示一个不严谨的采样可视化过程
假设在query在(0,1)分布,key在(0,1)分布,随机采样lengthdim个点,然后统计querykey_T的散点的分布

import math
import numpy as np
import matplotlib.pyplot as plt


def plot_curve(mu=0, sigma =1):
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.stats import norm
    # 设置正态分布的参数
    # mu, sigma = 0, 1  # 均值和标准差
    # 创建一个x值的范围,覆盖正态分布的整个区间
    x = np.linspace(mu - 4 * sigma, mu + 4 * sigma, 1000)
    # 计算对应的正态分布的概率密度值
    y = norm.pdf(x, mu, sigma)
    # 我们可以选择y值较高的点来绘制散点图,以模拟概率密度的分布
    # 这里我们可以设置一个阈值,只绘制y值大于某个值的点
    threshold = 0.01  # 可以根据需要调整这个阈值
    selected_points = y > threshold
    plt.plot(x, y, 'r-', lw=2, label='Normal dist. (mu={}, sigma={})'.format(mu, sigma))
    plt.title('Normal Distribution Scatter Approximation')
    plt.xlabel('Value')
    plt.ylabel('Probability Density')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_poins(x):
    # 因为这是一个一维的正态分布,我们通常只绘制x轴上的点
    # 但为了模拟二维散点图,我们可以简单地将y轴设置为与x轴相同或固定值(例如0)
    y = np.zeros_like(x)
    # 绘制散点图
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, alpha=0.5)  # alpha控制点的透明度
    plt.title('Normal (0, 1) Distribution Scatter Plot')
    plt.xlabel('Value')
    plt.ylabel('Value (or Frequency if binned)')
    plt.grid(True)
    plt.show()



if __name__ == '__main__':
    # 设置随机种子以便结果可复现
    np.random.seed(0)
    len = 10000
    dim = 100
    query = np.random.normal(0, 1, len*dim).reshape(len,dim)
    key = np.random.normal(0, 1, len*dim).reshape(dim,len)
    qk = np.matmul(query,key) / math.sqrt(dim)

    mean_query = query.mean()
    std_query = np.std(query,ddof=1)

    mean_key = key.mean()
    std_key = np.std(key,ddof=1)

    mean_qk = qk.mean()
    std_qk = np.std(qk,ddof=1)

    plot_poins(query)
    plot_curve(mean_query,std_query)

在这里插入图片描述

### Transformer 模型中的根号 d 的含义与作用 在 Transformer 模型中,自注意力机制(Self-Attention Mechanism)是一个核心组件。该机制通过计算查询向量 \( Q \) 和键向量 \( K \) 之间的点积来衡量不同位置的重要性权重。然而,在实际操作过程中,为了保持数值稳定性和梯度稳定性,通常会对这个点积的结果进行缩放处理。 具体来说,当计算 Query (Q) 和 Key (K) 矩阵之间点乘时,会得到一个大小为 \( n \times n \) 的矩阵,其中每一项代表两个词之间的相似程度得分。如果直接使用这些原始分数,则可能导致激活函数饱和以及训练不稳定等问题。因此引入了一个缩放因子 \( \frac{1}{\sqrt{d_k}} \),这里 \( d_k \) 表示 key 向量维度[^3]。 这种做法可以防止随着输入长度增加而导致的方差增大问题,并有助于维持合理的分布范围内的输出值。此外,这样的调整也有助于加速收敛并提高最终性能表现[^4]。 ```python import math import torch from torch import Tensor def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor: """ 计算缩放点积注意力. 参数: query (Tensor): 查询张量形状为 [batch_size, seq_len_q, depth]. key (Tensor): 键张量形状为 [batch_size, seq_len_k, depth]. value (Tensor): 值张量形状为 [batch_size, seq_len_v, depth]. 返回: context_vector (Tensor): 上下文向量. """ matmul_qk = torch.matmul(query, key.transpose(-2, -1)) # 点积运算 dk = float(key.shape[-1]) logits = matmul_qk / math.sqrt(dk) # 缩放处理 attention_weights = torch.softmax(logits, dim=-1) # 应用 Softmax 函数获取注意力权重 output = torch.matmul(attention_weights, value) # 加权求和得到上下文表示 return output ``` 通过对点积结果除以 \( \sqrt{d_k} \),使得模型能够在更广泛的场景下稳健工作,同时也促进了更好的泛化能力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

知其所以然也

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值