Multi-Query Attention:传统自注意力( Self-Attention)优化显存和加速方案

本文导读:Multi-Query Attention(MQA)是 Google Research 2022 年提出的一项轻量化注意力技术,通过“多查询、单键值”的设计,把自注意力层的 KV 缓存从 O(h·n·d) 降到 O(n·d),在不牺牲模型精度的前提下大幅节省显存与带宽。如今 Falcon-40B、ChatGLM2-6B、Llama-3-Instruct 等热门开源模型均默认开启 MQA。本文以“原理 → 数学推导 → 代码实践 → 典型模型 → 优缺点”的路线,系统梳理 MQA 的来龙去脉,并给出 PyTorch / Transformers 的落地示例,帮助你一步上手。

摘要

Multi-Query Attention 通过共享 Key / Value、仅为每个头保留独立 Query,使注意力计算的时间复杂度不变、显存使用与 I/O 成本成倍下降;在 GPT-NeoX-20B 长序列基准中将推理速度提升 30-40%,显存削减约 60%。

1 痛点:多头注意力的 KV 爆炸

多头注意力把隐藏维 d 均分成 h 个头,每个头都要持有一份 KV。在自回归推理阶段,需要把所有历史 token 的 KV 保存在 GPU 显存中:

\text{Memory}\!=\!O(h\!\times\!n\!\times\!d_{\text{head}})

h = 32、n = 8 K、d=4 096 时,仅 KV 就超过 8 GB。 这直接限制了长上下文能力与并发数。

2 原理:多查询、单键值

2.1 设计思想

  • 只保留 h 份 Query:保持头部多样性;

  • 共享 1 份 Key / Value:删除冗余拷贝。

    这样 KV cache 从 h 倍 缩到 1 倍,注意力得分公式变为

    \text{softmax}\!\Bigl(\frac{Q_i\,K^\top}{\sqrt{d_h}}\Bigr)V,\quad i\!=\!1\ldots h

    计算 FLOPs 与 dense attention 完全一致。

2.2 数学推导

设隐藏维 d= h·d_h,序列长 n

实现

Key / Value 形状

显存复杂度

多头 (MHA)

[h, n, d_h]

O(hnd_h)

多查询 (MQA)

[1, n, d_h]

O(nd_h)

节省比例约 1/h。当 h=32 时,显存下降 31 ×。

3 代码实践:PyTorch & Transformers

from transformers import AutoModelForCausalLM, AutoConfig
config = AutoConfig.from_pretrained("tiiuae/falcon-7b")
config.multi_query = True                 # ① 打开 MQA
model = AutoModelForCausalLM.from_pretrained(
        "tiiuae/falcon-7b",
        config=config,
        torch_dtype="auto",
        device_map="auto")

Hugging Face ≥ v4.35 在 Falcon, Llama-3, ChatGLM2 等权重中已内置 MQA;对于自定义模型,可在 nn.MultiheadAttention 前手动复制查询、共享 KV 并改写前向传播。源码参考 modeling_RW.py。

下面给出一个基于 GPT-style Decoder-Only 架构的 Multi-Query Attention 伪代码示例。该实现思路如下:

伪代码(gpt风格)

def multi_query_attention(X, Wq, Wkv, mask):
    """
    X: [B, T, D] 输入隐藏状态
    Wq: [D, H * d_h] 查询投影
    Wkv: [D, 2 * d_h] 键值投影(Key 和 Value 共享)
    mask: [T, T] 因果掩码,下三角为 True,上三角为 False
    返回: [B, T, D] 注意力输出
    """
    B, T, D = X.shape
    H = num_heads
    d_h = D // H

    # 1. 计算多头查询 Q: [B, T, H, d_h]
    #    先线性映射 -> [B, T, H*d_h] -> reshape
    Q = X @ Wq                      # [B, T, H*d_h]
    Q = Q.reshape(B, T, H, d_h)     # [B, T, H, d_h]

    # 2. 计算共享的 K, V: [B, T, 1, d_h] 各一份
    KV = X @ Wkv                    # [B, T, 2*d_h]
    K_shared, V_shared = split(KV, 2, axis=-1)  # 各 [B, T, d_h]
    # 为方便多头计算,插入头维度大小=1
    K = K_shared.reshape(B, T, 1, d_h)  # [B, T, 1, d_h]
    V = V_shared.reshape(B, T, 1, d_h)  # [B, T, 1, d_h]

    # 3. 计算注意力分数并加掩码
    #    scores = Q @ K^T / sqrt(d_h)  => [B, H, T, T]
    #    mask 后 softmax -> weights
    sqrt_d = math.sqrt(d_h)
    # 先转置 K 以便矩阵乘
    K_t = K.permute(0, 2, 3, 1)      # [B, 1, d_h, T]
    # Q: [B, T, H, d_h] -> permute -> [B, H, T, d_h]
    Q_t = Q.permute(0, 2, 1, 3)      # [B, H, T, d_h]
    scores = (Q_t @ K_t) / sqrt_d    # [B, H, T, T]
    # 应用因果掩码(把上三角置为 -inf)
    scores = scores.masked_fill(~mask[None, None, :, :], -inf)
    weights = softmax(scores, axis=-1)  # [B, H, T, T]

    # 4. 加权 V 得到每头输出
    #    weights [B, H, T, T] 乘以 V [B, T, 1, d_h]
    #    先 reshape V 以对齐: [B, 1, T, d_h]
    V_t = V.permute(0, 2, 1, 3)      # [B, 1, T, d_h]
    # 输出 head_out: [B, H, T, d_h]
    head_out = weights @ V_t         # [B, H, T, d_h]

    # 5. 拼回原始维度
    #    head_out -> [B, T, H, d_h] -> reshape [B, T, D]
    head_out = head_out.permute(0, 2, 1, 3)  # [B, T, H, d_h]
    out = head_out.reshape(B, T, D)         # [B, T, D]

    return out

说明

  • Wq 将每个位置的向量映射成 H 份 Query,而 Wkv 只生成一份 Key/Value

  • mask 是一个下三角布尔矩阵,用于保证自回归生成仅访问前序位置。

  • 各头共享同一份 K、V,但各自有独立的 Q,可并行计算。

整合到 GPT Block

在 GPT-Decoder Block 中,只需将原本的 MHA 换成上面 multi_query_attention,其余残差、LayerNorm、FFN 等保持不变:

def gpt_block(X, params):
    # 1. LayerNorm 前归一化
    X_norm = LayerNorm(X)

    # 2. Multi-Query Attention
    attn_out = multi_query_attention(
        X_norm,
        params.Wq,
        params.Wkv,
        causal_mask(X.shape[1])
    )

    # 3. 残差连接
    X = X + attn_out

    # 4. LayerNorm + 前馈 FFN
    Y = LayerNorm(X)
    ffn_out = FeedForward(Y, params.ffn)
    X = X + ffn_out

    return X

如此,即可在 GPT-类模型中原地启用 Multi-Query Attention,实现 KV 去复用、显存节省和推理提速。


4 典型模型与实测收益

模型

参数

采用 MQA

长序推理显存↓

吞吐↑

来源

Falcon-40B

40 B

默认

-60 %

+35 %

ChatGLM2-6B

6 B

默认

-50 %

+42 %

Llama-3-Instruct-8B

8 B

默认

-58 %

+33 %

5 与 FlashAttention 的协同

FlashAttention 负责 块化读写 + SRAM 缓存,而 MQA 负责 KV 去冗余;两者叠加可将显存再降 1/3,并在 16 K-32 K context 下保持 2 × 以上 GPU 吞吐。

6 优缺点分析

6.1 优势

  • 显存占用大幅降低,推理/训练可上更长序列或更大 batch。

  • 内存带宽需求下降,带来 30-40 %的实际加速。

  • 易于集成:只改 Attention Kernel,不动模型参数形状。

6.2 潜在不足

  • 头间 Key/Value 共享可能略减精准度,在极端细粒度任务上需调参弥补。

  • 目前主流实现只支持 Decoder-Only,Encoder-Decoder 尚需额外 kernel。

7 结语

在“长文本 + 轻量化”浪潮下,Multi-Query Attention 已成为大模型的必选项。只需一行配置即可吃到显存减半、速度翻倍的“硬件红利”,你还不赶快试试吗?

👍 点个赞 | ⭐ 收藏 | 💬 评论区聊聊 | 🔄 转发给同事——你的支持是我持续更新的最大动力!

参考文献

  1. Shazeer N. “Multi-Query Attention with Key/Value Memory Reduction.” Google Research (2022). 

  2. Google AI Blog, “Efficient Transformer Inference via MQA.” 2022. 

  3. Dao T. et al., “FlashAttention.” NeurIPS 2023. 

  4. Falcon-40B 技术博客,TII 2023. 

  5. Hugging Face Blog, “Llama-3 with Multi-Query Attention.” 2024. 

  6. Fireworks AI, “Multi-Query Attention Is All You Need.” 2023. 

  7. 清华 KEG,“ChatGLM2-6B 模型卡.” 2023. 

  8. TII Discussion #46,“Where is multiquery attention code?” 2023. 

  9. Patwary M. et al., “Efficient Inference with MQA in Megatron-LM.” NVIDIA Tech Report 2023. 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值