本文导读:Multi-Query Attention(MQA)是 Google Research 2022 年提出的一项轻量化注意力技术,通过“多查询、单键值”的设计,把自注意力层的 KV 缓存从 O(h·n·d) 降到 O(n·d),在不牺牲模型精度的前提下大幅节省显存与带宽。如今 Falcon-40B、ChatGLM2-6B、Llama-3-Instruct 等热门开源模型均默认开启 MQA。本文以“原理 → 数学推导 → 代码实践 → 典型模型 → 优缺点”的路线,系统梳理 MQA 的来龙去脉,并给出 PyTorch / Transformers 的落地示例,帮助你一步上手。
摘要
Multi-Query Attention 通过共享 Key / Value、仅为每个头保留独立 Query,使注意力计算的时间复杂度不变、显存使用与 I/O 成本成倍下降;在 GPT-NeoX-20B 长序列基准中将推理速度提升 30-40%,显存削减约 60%。
1 痛点:多头注意力的 KV 爆炸
多头注意力把隐藏维 d 均分成 h 个头,每个头都要持有一份 K 与 V。在自回归推理阶段,需要把所有历史 token 的 KV 保存在 GPU 显存中:
当 h = 32、n = 8 K、d=4 096 时,仅 KV 就超过 8 GB。 这直接限制了长上下文能力与并发数。
2 原理:多查询、单键值
2.1 设计思想
-
只保留 h 份 Query:保持头部多样性;
-
共享 1 份 Key / Value:删除冗余拷贝。
这样 KV cache 从 h 倍 缩到 1 倍,注意力得分公式变为
计算 FLOPs 与 dense attention 完全一致。
2.2 数学推导
设隐藏维 ,序列长 n:
实现 | Key / Value 形状 | 显存复杂度 |
---|---|---|
多头 (MHA) | | |
多查询 (MQA) | |
节省比例约 1/h。当 h=32 时,显存下降 31 ×。
3 代码实践:PyTorch & Transformers
from transformers import AutoModelForCausalLM, AutoConfig
config = AutoConfig.from_pretrained("tiiuae/falcon-7b")
config.multi_query = True # ① 打开 MQA
model = AutoModelForCausalLM.from_pretrained(
"tiiuae/falcon-7b",
config=config,
torch_dtype="auto",
device_map="auto")
Hugging Face ≥ v4.35 在 Falcon, Llama-3, ChatGLM2 等权重中已内置 MQA;对于自定义模型,可在 nn.MultiheadAttention 前手动复制查询、共享 KV 并改写前向传播。源码参考 modeling_RW.py。
下面给出一个基于 GPT-style Decoder-Only 架构的 Multi-Query Attention 伪代码示例。该实现思路如下:
伪代码(gpt风格)
def multi_query_attention(X, Wq, Wkv, mask):
"""
X: [B, T, D] 输入隐藏状态
Wq: [D, H * d_h] 查询投影
Wkv: [D, 2 * d_h] 键值投影(Key 和 Value 共享)
mask: [T, T] 因果掩码,下三角为 True,上三角为 False
返回: [B, T, D] 注意力输出
"""
B, T, D = X.shape
H = num_heads
d_h = D // H
# 1. 计算多头查询 Q: [B, T, H, d_h]
# 先线性映射 -> [B, T, H*d_h] -> reshape
Q = X @ Wq # [B, T, H*d_h]
Q = Q.reshape(B, T, H, d_h) # [B, T, H, d_h]
# 2. 计算共享的 K, V: [B, T, 1, d_h] 各一份
KV = X @ Wkv # [B, T, 2*d_h]
K_shared, V_shared = split(KV, 2, axis=-1) # 各 [B, T, d_h]
# 为方便多头计算,插入头维度大小=1
K = K_shared.reshape(B, T, 1, d_h) # [B, T, 1, d_h]
V = V_shared.reshape(B, T, 1, d_h) # [B, T, 1, d_h]
# 3. 计算注意力分数并加掩码
# scores = Q @ K^T / sqrt(d_h) => [B, H, T, T]
# mask 后 softmax -> weights
sqrt_d = math.sqrt(d_h)
# 先转置 K 以便矩阵乘
K_t = K.permute(0, 2, 3, 1) # [B, 1, d_h, T]
# Q: [B, T, H, d_h] -> permute -> [B, H, T, d_h]
Q_t = Q.permute(0, 2, 1, 3) # [B, H, T, d_h]
scores = (Q_t @ K_t) / sqrt_d # [B, H, T, T]
# 应用因果掩码(把上三角置为 -inf)
scores = scores.masked_fill(~mask[None, None, :, :], -inf)
weights = softmax(scores, axis=-1) # [B, H, T, T]
# 4. 加权 V 得到每头输出
# weights [B, H, T, T] 乘以 V [B, T, 1, d_h]
# 先 reshape V 以对齐: [B, 1, T, d_h]
V_t = V.permute(0, 2, 1, 3) # [B, 1, T, d_h]
# 输出 head_out: [B, H, T, d_h]
head_out = weights @ V_t # [B, H, T, d_h]
# 5. 拼回原始维度
# head_out -> [B, T, H, d_h] -> reshape [B, T, D]
head_out = head_out.permute(0, 2, 1, 3) # [B, T, H, d_h]
out = head_out.reshape(B, T, D) # [B, T, D]
return out
说明
-
Wq 将每个位置的向量映射成 H 份 Query,而 Wkv 只生成一份 Key/Value。
-
mask 是一个下三角布尔矩阵,用于保证自回归生成仅访问前序位置。
-
各头共享同一份 K、V,但各自有独立的 Q,可并行计算。
整合到 GPT Block
在 GPT-Decoder Block 中,只需将原本的 MHA 换成上面 multi_query_attention,其余残差、LayerNorm、FFN 等保持不变:
def gpt_block(X, params):
# 1. LayerNorm 前归一化
X_norm = LayerNorm(X)
# 2. Multi-Query Attention
attn_out = multi_query_attention(
X_norm,
params.Wq,
params.Wkv,
causal_mask(X.shape[1])
)
# 3. 残差连接
X = X + attn_out
# 4. LayerNorm + 前馈 FFN
Y = LayerNorm(X)
ffn_out = FeedForward(Y, params.ffn)
X = X + ffn_out
return X
如此,即可在 GPT-类模型中原地启用 Multi-Query Attention,实现 KV 去复用、显存节省和推理提速。
4 典型模型与实测收益
模型 | 参数 | 采用 MQA | 长序推理显存↓ | 吞吐↑ | 来源 |
---|---|---|---|---|---|
Falcon-40B | 40 B | 默认 | -60 % | +35 % | |
ChatGLM2-6B | 6 B | 默认 | -50 % | +42 % | |
Llama-3-Instruct-8B | 8 B | 默认 | -58 % | +33 % |
5 与 FlashAttention 的协同
FlashAttention 负责 块化读写 + SRAM 缓存,而 MQA 负责 KV 去冗余;两者叠加可将显存再降 1/3,并在 16 K-32 K context 下保持 2 × 以上 GPU 吞吐。
6 优缺点分析
6.1 优势
-
显存占用大幅降低,推理/训练可上更长序列或更大 batch。
-
内存带宽需求下降,带来 30-40 %的实际加速。
-
易于集成:只改 Attention Kernel,不动模型参数形状。
6.2 潜在不足
-
头间 Key/Value 共享可能略减精准度,在极端细粒度任务上需调参弥补。
-
目前主流实现只支持 Decoder-Only,Encoder-Decoder 尚需额外 kernel。
7 结语
在“长文本 + 轻量化”浪潮下,Multi-Query Attention 已成为大模型的必选项。只需一行配置即可吃到显存减半、速度翻倍的“硬件红利”,你还不赶快试试吗?
👍 点个赞 | ⭐ 收藏 | 💬 评论区聊聊 | 🔄 转发给同事——你的支持是我持续更新的最大动力!
参考文献
-
Shazeer N. “Multi-Query Attention with Key/Value Memory Reduction.” Google Research (2022).
-
Google AI Blog, “Efficient Transformer Inference via MQA.” 2022.
-
Dao T. et al., “FlashAttention.” NeurIPS 2023.
-
Falcon-40B 技术博客,TII 2023.
-
Hugging Face Blog, “Llama-3 with Multi-Query Attention.” 2024.
-
Fireworks AI, “Multi-Query Attention Is All You Need.” 2023.
-
清华 KEG,“ChatGLM2-6B 模型卡.” 2023.
-
TII Discussion #46,“Where is multiquery attention code?” 2023.
-
Patwary M. et al., “Efficient Inference with MQA in Megatron-LM.” NVIDIA Tech Report 2023.