缩放点积注意力(Scaled Dot-Product Attention)

缩放点积注意力(Scaled Dot-Product Attention)

缩放点积注意力(Scaled Dot-Product Attention)是自注意力(Self-Attention)机制的一种变体,它被广泛应用于现代的神经网络架构中,尤其是在 Transformer 中。它的核心思想是利用输入序列中各个位置的 查询(Query)键(Key)值(Value) 来计算注意力权重,并通过加权求和的方式生成上下文向量。

数学原理

对于给定的查询 Q Q Q、键 K K K 和值 V V V,缩放点积注意力通过以下步骤计算上下文向量:

  1. 计算点积:首先,计算查询和键之间的点积,得到一个注意力得分矩阵。该矩阵表示了每个查询向量与所有键向量之间的相似度。

    scores = Q K T \text{scores} = QK^T scores=QKT

    其中:

    • Q Q Q 是查询矩阵,维度为 ( n queries , d k ) (n_{\text{queries}}, d_k) (nqueries,dk)
    • K K K 是键矩阵,维度为 ( n keys , d k ) (n_{\text{keys}}, d_k) (nkeys,dk)
    • d k d_k dk 是每个查询和键向量的维度。
  2. 缩放(Scaling):由于点积的结果会随着向量维度的增大而增大,这会导致梯度消失或者梯度爆炸等问题,因此通过除以一个常数 d k \sqrt{d_k} dk 来进行缩放。这个缩放操作可以帮助避免点积值过大。

    scaled_scores = Q K T d k \text{scaled\_scores} = \frac{QK^T}{\sqrt{d_k}} scaled_scores=dk QKT

  3. 应用 softmax:接下来,通过 softmax 函数对每个查询的得分进行归一化,得到注意力权重矩阵。这些权重矩阵会表明每个查询向量对所有键向量的关注程度。

    attention_weights = softmax ( Q K T d k ) \text{attention\_weights} = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) attention_weights=softmax(dk QKT)

  4. 加权求和值:最后,将这些注意力权重与值矩阵 V V V 相乘,得到加权后的值,生成上下文向量。

    output = attention_weights ⋅ V \text{output} = \text{attention\_weights} \cdot V output=attention_weightsV

最终公式总结

将以上步骤综合起来,缩放点积注意力的计算过程如下:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk QKT)V

步骤详细解释

  1. 查询(Query)、键(Key)和值(Value):在计算注意力时,查询、键和值通常来自同一个输入数据。查询用于表示当前的兴趣点,而键和值则用于提供信息。注意力权重将决定如何从值中提取信息。

  2. 缩放操作:缩放操作的目的是通过 d k \sqrt{d_k} dk 来避免随着向量维度增大,点积值过大导致的数值不稳定性。

  3. softmax:softmax 操作确保了每个查询的注意力权重和为 1,这样可以通过加权平均的方式得到每个值的贡献。

  4. 加权求和:通过将注意力权重与值矩阵相乘,最终得到一个上下文向量,表示该查询所关注的信息。

应用场景

缩放点积注意力是 Transformer 架构的基础,它广泛应用于:

  • 机器翻译:Transformer 是当前自然语言处理(NLP)领域的主流模型,它依赖于自注意力机制来并行处理输入序列中的信息。
  • 图像处理:在 Vision Transformer(ViT)中,缩放点积注意力用于处理图像数据。
  • 语言模型:像 GPT(Generative Pretrained Transformer)和 BERT(Bidirectional Encoder Representations from Transformers)等模型都使用了缩放点积注意力来捕获句子中的长期依赖关系。

PyTorch 实现

下面是一个简单的缩放点积注意力的 PyTorch 实现:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """
    实现缩放点积注意力
    
    参数:
    Q (Tensor): 查询矩阵,形状为 (batch_size, num_queries, d_k)
    K (Tensor): 键矩阵,形状为 (batch_size, num_keys, d_k)
    V (Tensor): 值矩阵,形状为 (batch_size, num_keys, d_v)
    
    返回:
    Tensor: 输出上下文向量,形状为 (batch_size, num_queries, d_v)
    """
    # 计算 Q 和 K 的点积
    matmul_qk = torch.matmul(Q, K.transpose(-2, -1))  # (batch_size, num_queries, num_keys)
    
    # 缩放点积
    d_k = Q.size(-1)  # 获取 d_k
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    # 使用 softmax 归一化
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    
    # 使用注意力权重加权求和值矩阵 V
    output = torch.matmul(attention_weights, V)  # (batch_size, num_queries, d_v)
    
    return output, attention_weights

# 假设我们有以下输入
batch_size = 2
num_queries = 3
num_keys = 4
d_k = 5
d_v = 6

Q = torch.randn(batch_size, num_queries, d_k)  # 查询矩阵
K = torch.randn(batch_size, num_keys, d_k)  # 键矩阵
V = torch.randn(batch_size, num_keys, d_v)  # 值矩阵

output, attention_weights = scaled_dot_product_attention(Q, K, V)

print("Output shape:", output.shape)  # 输出形状 (batch_size, num_queries, d_v)
print("Attention weights shape:", attention_weights.shape)  # 注意力权重形状 (batch_size, num_queries, num_keys)

总结

  • 缩放点积注意力通过计算查询和键之间的点积,缩放结果,应用 softmax 得到注意力权重,然后用这些权重加权求和值向量,得到上下文向量。
  • 缩放操作是其一个重要特点,它避免了点积值随维度增大而导致的数值不稳定。
  • Transformer 架构采用了缩放点积注意力,它在自然语言处理(NLP)、图像处理等领域都有广泛应用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值