手撕大模型 | MQA 和 GQA 原理解析

一、前言

大模型(参数规模通常数十亿至万亿级)在处理复杂任务时面临三大核心问题:

  1. 显式关联的局限性:传统 Multi-head Attention 依赖输入数据的显式特征(如文本中的词向量、图像中的像素特征)计算注意力,难以捕捉深层语义(如 “同义词替换”“上下文隐喻”)或抽象结构(如 “逻辑推理链”)。
  2. 数据效率与泛化瓶颈:大模型训练需海量数据,但在低资源语言、专业领域(如医学、法律)中,显式关联数据稀缺,导致模型泛化能力骤降。
  3. 多模态融合难点:跨模态任务(如图文生成、视频理解)中,不同模态的特征空间差异大(如文本的离散符号 vs 图像的连续像素),显式关联(如 “图像中的猫” 与文本 “猫”)之外的隐式关联(如 “图像风格” 与 “文本情感”)难以建模。

在前面的文章中,笔者已经讲解了 LLM 推理的关键技术-KV Cache(【手撕大模型】KVCache 原理及代码解析),但是随着大模型功能的不断强化,其容量也在增加,当前的 KVCache 技术已经不能满足发展需要了,所以,各种针对于 KVCache 优化的技术应时而生。

二、优化 KV cache 的方法

参考 https://zhuanlan.zhihu.com/p/16730036197

当前,业界针对 KV Cache 的优化方法可以总结为有四类:

  1. 共享 KV:多个 Head 共享使用 1 组 KV,将原来每个 Head 一个 KV,变成 1 组 Head 一个 KV,来压缩 KV 的存储。代表方法:GQA,MQA 等。
  2. 窗口 KV:针对长序列控制一个计算 KV 的窗口,KV cache 只保存窗口内的结果(窗口长度远小于序列长度),超出窗口的 KV 会被丢弃,通过这种方法能减少 KV 的存储,当然也会损失一定的长文推理效果。代表方法:Longformer 等。
  3. 量化压缩:基于量化的方法,通过更低的 Bit 位来保存 KV,将单 KV 结果进一步压缩,代表方法:INT8/INT4 等。
  4. 计算优化:通过优化计算过程,减少访存换入换出的次数,让更多计算在片上存储 SRAM 进行,以提升推理性能,代表方法:flashAttention 等。

共享 KV 主要有两种方法,MQA 和 GQA 都是 Google 提出的,详见: MQA(2019)GQA(2023)

三、MQA &

MQA(多查询注意力)和 GQA(分组查询注意力)作为自注意力机制的优化版本,主要作用是加快推理进程、减少内存占用,同时努力维持模型原有的性能表现。

以 Llama 7B 模型为例,其隐藏层维度为 4096,这意味着每个 K、V 向量都包含 4096 个数据。若采用半精度浮点(float16)格式存储,单个 Transformer 模块中,单序列的 K、V 缓存空间就达到 4096×2×2=16KB。由于 Llama 2 包含 32 个 Transformer 模块,单个序列在整个模型中的缓存需求便为 16KB×32=512KB。

那么多序列的情况呢?倘若句子长度为 1024,缓存空间就会增至 512MB。目前英伟达性能顶尖的 H100 显卡,其 SRAM 缓存约为 50MB,A100 则为 40MB,显然难以满足需求。尽管可将数据存于 GPU 显存(DRAM),但会对性能产生影响。7B 规模的模型已是如此,175B 规模的模型面临的问题更严峻。

解决这一问题的思路可从硬件与软件两方面展开:

  • 硬件层面,可借助 HBM(高带宽内存)提高数据读取速度;或摆脱冯・诺依曼架构的束缚,改变计算单元从内存读取数据的方式,转而以存储为核心,构建计算与存储一体化的 “存内计算” 模式,例如采用 “忆阻器” 技术。
  • 软件层面则通过算法优化来解决,Llama 2 所采用的 GQA(分组查询注意力)便是其中一种方案。

下面将通过图示来展示 MQA、GQA 与传统 MHA(多头注意力)的差异:

img

多头注意力机制(MHA)就是多个头各自拥有自己的 Q,K,V 来算各自的 Self-Attention,而 MQA(Multi Query Attention)就是 Q 依然保持多头,但是 K,V 只有一个,所有多头的 Q 共享一个 K,V ,这样做虽然能最大程度减少 KV Cache 所需的缓存空间,但是可想而知参数的减少意味着精度的下降,所以为了在精度和计算之间做一个 trade-off,GQA (Group Query Attention)孕育而生,即 Q 依然是多头,但是分组共享 K,V,即减少了 K,V 缓存所需的缓存空间,也暴露了大部分参数不至于精度损失严重。

四、MQA

MQA 的思路比较简单,详见上图,每一层的所有 Head,共享同一个 KV 来计算 Attention。相对于 MHA 的单个 Token 需要保存的 KV 数减少了 n_h 倍(head 数量),即每一层共享使用一个 Q 向量和一个 V 向量。

使用 MQA 的模型包括 PaLMStarCoderGemini 等。很明显,MQA 直接将 KV Cache 减少到了原来的 1/n_h,这是非常可观的,单从节省显存角度看已经是天花板了。

效果方面,目前看来大部分任务的损失都比较有限,且 MQA 的支持者相信这部分损失可以通过进一步训练来弥补回。此外,注意到 MQA 由于共享了 K、V,将会导致 Attention 的参数量减少了将近一半,而为了模型总参数量的不变,通常会相应地增大 FFN/GLU 的规模,这也能弥补一部分效果损失。

五、GQA

GQA 是平衡了 MQA 和 MHA 的一种折中的方法,不是每个 Head 一个 KV,也不是所有 Head 共享一个 KV,而是对所有 Head 分组,比如分组数为 g ,那么每组: n_h/g 个 Head 共享一个 KV。当 g=1 时,GQA 就等价于 MQA,当 g=n_h 时, GQA 就等价于 MHA。

为了方便更清晰的理解 GQA 和 MQA ,使用一个 Token 计算 KV 过程来进行演示:

img

总结下单 token 计算下,几种方法 KV Cache 的存储量(模型层数:l,每层 Head 数量:n_h )

img

六、参考链接

https://zhuanlan.zhihu.com/p/16730036197

54376)]

六、参考链接

https://zhuanlan.zhihu.com/p/16730036197

https://spaces.ac.cn/archives/10091

### MQA 源码解析与编码实现详解 #### 1. 多查询注意力机制 (Multi-Query Attention, MQA) 的核心概念 多查询注意力是一种优化版本的自注意力机制,旨在减少计算开销的同时保持模型性能。在标准的多头注意力(MHA)中,每个头部都有独立的键(Key)、值(Value)查询(Query)。然而,在MQA中,所有的头部共享一组公共的键值投影矩阵[^3]。 这种设计显著减少了参数数量以及内存占用,因为只需要维护少量的键值权重矩阵即可支持多个查询头的操作。具体来说,在MQA中,`key` `value` 被压缩至单个维度大小(通常是head_dim),而 `query` 则保留完整的维度以捕获更丰富的特征表示。 #### 2. MQA 的数学描述 假设输入张量 \( X \in R^{T\times d_{model}} \),其中 T 表示序列长度,\(d_{model}\) 是隐藏层尺寸,则可以定义如下: \[ Q = XW_Q,\ K = XW_K,\ V = XW_V, \] 对于MQA而言, - 查询矩阵 Q 的形状为 \( [B,T,H,d_k] \),即 batch size × 序列长度 × 头数 × 单头维度; - 键矩阵 K 值矩阵 V 的形状则简化为 \( [B,T,d_k'] \),这里 \(d_k'\) 远小于原始的 \(H*d_k\) 总维度。 最终通过缩放点积操作得到注意力分布并加权求获得输出 O: ```python import torch from einops import rearrange def mqa_attention(query, key, value, num_heads=8): """ 实现简单的 Multi-Query Attention. 参数: query: shape [batch_size, seq_len, hidden_dim] key: shape [batch_size, seq_len, reduced_hidden_dim] value: shape [batch_size, seq_len, reduced_hidden_dim] num_heads: 注意力头的数量 返回: output: 输出张量,shape [batch_size, seq_len, hidden_dim] """ batch_size, seq_len, hidden_dim = query.shape _, _, reduced_hidden_dim = key.shape # 将 Query 投影成多头形式 q = rearrange(query, 'b s (h d) -> b h s d', h=num_heads) # Key/Value 不需要拆分成多头,直接广播匹配 k = key.unsqueeze(1).expand(-1, num_heads, -1, -1) v = value.unsqueeze(1).expand(-1, num_heads, -1, -1) # 计算注意力分数 scores = torch.einsum('bhqd,bhkd->bhqk', q, k) / (reduced_hidden_dim ** 0.5) # 归一化概率分布 attn_weights = torch.softmax(scores, dim=-1) # 加权平均 Value 向量 context_vectors = torch.einsum('bhqk,bhvd->bhqv', attn_weights, v) # Reshape 回原形状 output = rearrange(context_vectors, 'b h s d -> b s (h d)') return output ``` 上述代码展示了如何利用 PyTorch 构建一个基础版的 MQA 层次结构。注意这里的 `einsum()` 函数用于高效地完成批量矩阵乘法运算;同时借助 `einops.rearrange()` 方法方便地调整数据布局以便于后续处理。 #### 3. 关键差异对比分析 | 特性 | MHA | MQA | |--------------|------------------------------|-------------------------------| | Query Head | 每个头单独学习 | 所有头共享相同的一套KV | | Parameter Count | 较高 | 显著降低 | | Memory Usage || 更低 | 由此可见,尽管两者都属于 Transformer 家族成员之一,但在实际部署过程中如果资源受限或者追求极致效率的话,那么采用 MQA 方案无疑会更加合适一些。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值