1. 多头注意力(Multi‑Head Attention)原理
设输入序列表示为矩阵 X ∈ R B × L × d model X\in\mathbb{R}^{B\times L\times d_{\text{model}}} X∈RB×L×dmodel,其中
- B B B:批大小(batch size),
- L L L:序列长度(sequence length),
- d model d_{\text{model}} dmodel:模型隐层维度(model dimension)。
多头注意力基于对缩放点乘注意力的并行化扩展,引入了 h h h 个注意力头(heads),每个头在不同子空间中学习不同的表示。
1.1 线性映射与切分
我们首先为每个头定义三组可学习权重:
W
i
Q
∈
R
d
model
×
d
k
,
W
i
K
∈
R
d
model
×
d
k
,
W
i
V
∈
R
d
model
×
d
v
,
i
=
1
,
…
,
h
W_i^Q \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^K \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^V \in \mathbb{R}^{d_{\text{model}}\times d_v}, \quad i=1,\dots,h
WiQ∈Rdmodel×dk,WiK∈Rdmodel×dk,WiV∈Rdmodel×dv,i=1,…,h
其中
- h h h:头数(number of heads),
- d k d_k dk:每个头中 Query/Key 的维度(key/query dimension),
- d v d_v dv:每个头中 Value 的维度(value dimension),
- 通常 d model = h × d k d_{\text{model}}=h\times d_k dmodel=h×dk 且取 d v = d k d_v = d_k dv=dk。
对输入
X
X
X 进行投影,得到第
i
i
i 个头的查询、键、值:
Q
i
=
X
W
i
Q
,
K
i
=
X
W
i
K
,
V
i
=
X
W
i
V
Q_i = X\,W_i^Q,\quad K_i = X\,W_i^K,\quad V_i = X\,W_i^V
Qi=XWiQ,Ki=XWiK,Vi=XWiV
其中
- Q i ∈ R B × L × d k Q_i \in \mathbb{R}^{B\times L\times d_k} Qi∈RB×L×dk,
- K i ∈ R B × L × d k K_i \in \mathbb{R}^{B\times L\times d_k} Ki∈RB×L×dk,
- V i ∈ R B × L × d v V_i \in \mathbb{R}^{B\times L\times d_v} Vi∈RB×L×dv。
1.2 缩放点乘注意力
对第 i i i 个头,分别对所有位置做点积注意力:
- 打分矩阵
S c o r e i = Q i K i ⊤ ∈ R B × L × L \mathrm{Score}_i = Q_i\,K_i^\top \quad\in\mathbb{R}^{B\times L\times L} Scorei=QiKi⊤∈RB×L×L - 缩放
S c o r e ~ i = S c o r e i d k \widetilde{\mathrm{Score}}_i = \frac{\mathrm{Score}_i}{\sqrt{d_k}} Score i=dkScorei - Softmax 归一化
A i = s o f t m a x ( S c o r e ~ i ) ∈ R B × L × L A_i = \mathrm{softmax}\bigl(\widetilde{\mathrm{Score}}_i\bigr) \quad\in\mathbb{R}^{B\times L\times L} Ai=softmax(Score i)∈RB×L×L - 加权求和
h e a d i = A i V i ∈ R B × L × d v \mathrm{head}_i = A_i\,V_i \quad\in\mathbb{R}^{B\times L\times d_v} headi=AiVi∈RB×L×dv
1.3 拼接与线性变换
将所有头的输出在最后一维拼接,再做一次线性投影:
C
o
n
c
a
t
=
[
h
e
a
d
1
,
…
,
h
e
a
d
h
]
∈
R
B
×
L
×
(
h
d
v
)
\mathrm{Concat} = \bigl[\mathrm{head}_1,\dots,\mathrm{head}_h\bigr] \quad\in\mathbb{R}^{B\times L\times (h\,d_v)}
Concat=[head1,…,headh]∈RB×L×(hdv)
定义输出权重矩阵
W
O
∈
R
(
h
d
v
)
×
d
model
W^O\in\mathbb{R}^{(h\,d_v)\times d_{\text{model}}}
WO∈R(hdv)×dmodel
最终输出:
M
u
l
t
i
H
e
a
d
(
X
)
=
C
o
n
c
a
t
W
O
∈
R
B
×
L
×
d
model
\mathrm{MultiHead}(X) = \mathrm{Concat}\;W^O \quad\in\mathbb{R}^{B\times L\times d_{\text{model}}}
MultiHead(X)=ConcatWO∈RB×L×dmodel
2. PyTorch 代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, h: int):
"""
d_model: 模型维度 d_model
h: 注意力头数 h
"""
super().__init__()
assert d_model % h == 0, "d_model 必须能被 h 整除"
self.d_model = d_model # d_model
self.h = h # h
self.d_k = d_model // h # d_k = d_model / h
self.d_v = self.d_k # d_v 通常等于 d_k
# 投影矩阵 W_i^Q, W_i^K, W_i^V,实际上合并为一个大矩阵后在 forward 再切分
self.w_q = nn.Linear(d_model, d_model) # 同时生成 h 个 Q 投影
self.w_k = nn.Linear(d_model, d_model) # 同时生成 h 个 K 投影
self.w_v = nn.Linear(d_model, d_model) # 同时生成 h 个 V 投影
# 输出线性变换 W^O
self.w_o = nn.Linear(d_model, d_model)
def forward(self, X: torch.Tensor, mask: torch.Tensor = None):
"""
X: 输入张量,形状 (B, L, d_model)
mask: 可选掩码,形状 (B, 1, L, L) 或 (B, L, L)
"""
B, L, _ = X.size()
# 1. 线性映射到 Q, K, V,然后切分 h 头
# 先得到 (B, L, h*d_k),再 view/transpose 为 (B, h, L, d_k)
Q = self.w_q(X).view(B, L, self.h, self.d_k).transpose(1, 2)
K = self.w_k(X).view(B, L, self.h, self.d_k).transpose(1, 2)
V = self.w_v(X).view(B, L, self.h, self.d_k).transpose(1, 2)
# 此时 Q, K, V 形状均为 (B, h, L, d_k)
# 2. 计算点积注意力
# scores = Q @ K^T -> (B, h, L, L)
scores = torch.matmul(Q, K.transpose(-2, -1))
# 缩放:除以 sqrt(d_k)
scores = scores / math.sqrt(self.d_k)
# 可选掩码:将被屏蔽位置设为 -inf
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax 归一化 -> (B, h, L, L)
A = F.softmax(scores, dim=-1)
# 加权求和 -> head_i 形状 (B, h, L, d_k)
heads = torch.matmul(A, V)
# 3. 拼接 h 个头:transpose 回 (B, L, h, d_k) 再 reshape
concat = heads.transpose(1, 2).contiguous().view(B, L, self.h * self.d_k)
# concat 形状 (B, L, h*d_k) == (B, L, d_model)
# 4. 最后一次线性变换 W^O
output = self.w_o(concat) # -> (B, L, d_model)
return output, A # 返回输出及注意力权重 A