手动执行多头注意力的关键步骤

多头注意力机制是Transformer架构中的一个关键部分。 

多头注意力机制的优势在于它能够让模型同时关注输入数据的多个不同方面。例如,在处理文本数据时,一个头可能关注语法结构,而另一个头可能关注语义内容。通过这种方式,模型能够更全面地理解输入数据,从而提高其在各种任务中的性能。

工作流程:

  1. 输入准备:首先,有一组输入序列,这些序列可以是文本、图像或其他类型的数据,经过嵌入层转换为向量表示。

  2. 输入嵌入向量的分割:首先,输入的嵌入向量会被分割成多个较小的部分,每个部分对应一个注意力“头”。这些分割后的向量具有更低的维度,使得模型能够在更细粒度上学习数据的表示。例如,如果输入的嵌入向量维度是512,并且使用8个头,那么每个头将处理64维的向量。

  3. 线性变换生成Q、K、V:对输入向量应用三个不同的线性变换(即全连接层),分别生成查询(Query, Q)、键(Key, K)和值(Value, V)向量。这些变换允许模型学习到不同的表示,以更好地捕捉输入数据中的特征。

  4. 多头注意力计算
    在每个头上,使用缩放点积注意力计算公式来计算注意力权重。具体来说,对于每个头i,其计算公式可以表示为:\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i      其中,(Q_i)、(K_i)和(V_i)分别是第i个头的查询、键和值向量,(d_k)是键向量的维度,用于缩放点积以稳定softmax函数。这个计算过程在所有头上并行进行。

  5. 输出合并
    计算完所有头的注意力后,将它们的输出向量在最后一个维度上拼接起来。然后,再次通过一个线性变换得到多头注意力机制的最终输出。这个过程将多个头的信息融合在一起,使得模型能够同时关注到输入数据的不同方面。

总的来说,多头注意力机制通过分割输入嵌入向量、生成QKV向量、并行计算多个头的注意力权重以及合并输出等步骤,有效地捕捉输入数据的复杂特征关系,提高了模型的表达能力和学习效率。这种机制在自然语言处理、图像识别等领域有着广泛的应用。

下面进入正题,来个具体的栗子~

手动执行多头注意力的关键步骤

假设

  • 嵌入维度 n_embd: 8
  • 头数 n_head: 2
  • 序列长度 T: 3
  • 批次大小 B: 1
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import math  
  
# 设定参数  
n_embd = 8  
n_head = 2  
T = 3  
B = 1  
  
# 模拟一个输入张量 x,形状为 (B, T, n_embd)  
x = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],  
                   [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],  
                   [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]]], dtype=torch.float32)  
  
# 初始化一个线性层来模拟 c_attn,输出维度是输入的三倍(因为 q, k, v)  
c_attn = nn.Linear(n_embd, 3 * n_embd)  
c_attn.weight.data = torch.randn(3 * n_embd, n_embd)  
c_attn.bias.data = torch.zeros(3 * n_embd)  
  
# 通过线性层得到 q, k, v  
qkv = c_attn(x)  
q, k, v = qkv.chunk(3, dim=-1)  
  
# 将 q, k, v 分为多个头  
q = q.view(B, T, n_head, -1).transpose(1, 2)  
k = k.view(B, T, n_head, -1).transpose(1, 2)  
v = v.view(B, T, n_head, -1).transpose(1, 2)  
  
print("Query (q):\n", q)  
print("Key (k):\n", k)  
print("Value (v):\n", v)  
  
# 计算注意力分数  
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))  
print("Attention Scores:\n", attn_scores)  
  
# 应用因果掩码(上三角设置为负无穷)  
causal_mask = torch.tril(torch.ones(T, T), diagonal=0).bool()  
attn_scores_masked = attn_scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(1).expand_as(attn_scores), float('-inf'))  
print("Masked Attention Scores:\n", attn_scores_masked)  
  
# Softmax 得到注意力权重  
# 比如在翻译任务中,这些权重表示在生成当前词的翻译时,每个输入词的重要性。
attn_weights = F.softmax(attn_scores_masked, dim=-1)  
print("Attention Weights:\n", attn_weights)  
  
# 使用注意力权重和 Value 计算输出  
output = torch.matmul(attn_weights, v)  
print("Output:\n", output)  
  
# 将输出转换回原始形状  
output = output.transpose(1, 2).contiguous().view(B, T, n_embd)  
print("Final Output:\n", output)
  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值