多头注意力机制是Transformer架构中的一个关键部分。
多头注意力机制的优势在于它能够让模型同时关注输入数据的多个不同方面。例如,在处理文本数据时,一个头可能关注语法结构,而另一个头可能关注语义内容。通过这种方式,模型能够更全面地理解输入数据,从而提高其在各种任务中的性能。
工作流程:
-
输入准备:首先,有一组输入序列,这些序列可以是文本、图像或其他类型的数据,经过嵌入层转换为向量表示。
-
输入嵌入向量的分割:首先,输入的嵌入向量会被分割成多个较小的部分,每个部分对应一个注意力“头”。这些分割后的向量具有更低的维度,使得模型能够在更细粒度上学习数据的表示。例如,如果输入的嵌入向量维度是512,并且使用8个头,那么每个头将处理64维的向量。
-
线性变换生成Q、K、V:对输入向量应用三个不同的线性变换(即全连接层),分别生成查询(Query, Q)、键(Key, K)和值(Value, V)向量。这些变换允许模型学习到不同的表示,以更好地捕捉输入数据中的特征。
-
多头注意力计算:
在每个头上,使用缩放点积注意力计算公式来计算注意力权重。具体来说,对于每个头i,其计算公式可以表示为: 其中,(Q_i)、(K_i)和(V_i)分别是第i个头的查询、键和值向量,(d_k)是键向量的维度,用于缩放点积以稳定softmax函数。这个计算过程在所有头上并行进行。 -
输出合并:
计算完所有头的注意力后,将它们的输出向量在最后一个维度上拼接起来。然后,再次通过一个线性变换得到多头注意力机制的最终输出。这个过程将多个头的信息融合在一起,使得模型能够同时关注到输入数据的不同方面。
总的来说,多头注意力机制通过分割输入嵌入向量、生成QKV向量、并行计算多个头的注意力权重以及合并输出等步骤,有效地捕捉输入数据的复杂特征关系,提高了模型的表达能力和学习效率。这种机制在自然语言处理、图像识别等领域有着广泛的应用。
下面进入正题,来个具体的栗子~
手动执行多头注意力的关键步骤:
假设
- 嵌入维度
n_embd
: 8 - 头数
n_head
: 2 - 序列长度
T
: 3 - 批次大小
B
: 1
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# 设定参数
n_embd = 8
n_head = 2
T = 3
B = 1
# 模拟一个输入张量 x,形状为 (B, T, n_embd)
x = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
[2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]]], dtype=torch.float32)
# 初始化一个线性层来模拟 c_attn,输出维度是输入的三倍(因为 q, k, v)
c_attn = nn.Linear(n_embd, 3 * n_embd)
c_attn.weight.data = torch.randn(3 * n_embd, n_embd)
c_attn.bias.data = torch.zeros(3 * n_embd)
# 通过线性层得到 q, k, v
qkv = c_attn(x)
q, k, v = qkv.chunk(3, dim=-1)
# 将 q, k, v 分为多个头
q = q.view(B, T, n_head, -1).transpose(1, 2)
k = k.view(B, T, n_head, -1).transpose(1, 2)
v = v.view(B, T, n_head, -1).transpose(1, 2)
print("Query (q):\n", q)
print("Key (k):\n", k)
print("Value (v):\n", v)
# 计算注意力分数
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
print("Attention Scores:\n", attn_scores)
# 应用因果掩码(上三角设置为负无穷)
causal_mask = torch.tril(torch.ones(T, T), diagonal=0).bool()
attn_scores_masked = attn_scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(1).expand_as(attn_scores), float('-inf'))
print("Masked Attention Scores:\n", attn_scores_masked)
# Softmax 得到注意力权重
# 比如在翻译任务中,这些权重表示在生成当前词的翻译时,每个输入词的重要性。
attn_weights = F.softmax(attn_scores_masked, dim=-1)
print("Attention Weights:\n", attn_weights)
# 使用注意力权重和 Value 计算输出
output = torch.matmul(attn_weights, v)
print("Output:\n", output)
# 将输出转换回原始形状
output = output.transpose(1, 2).contiguous().view(B, T, n_embd)
print("Final Output:\n", output)