Transformer升级之路:多头潜在注意力机制(MLA)究竟好在哪里?

©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 科学空间

研究方向 | NLP、神经网络

自从 DeepSeek 爆火后,它所提的 Attention 变体 MLA(Multi-head Latent Attention)也愈发受到关注。

MLA 通过巧妙的设计实现了 MHA 与 MQA 的自由切换,使得模型可以根据训练和推理的不同特性(Compute-Bound or Memory-Bound)选择最佳的形式,尽可能地达到效率最大化。 

诚然,MLA 很有效,但也有观点认为它不够优雅,所以寻找 MLA 替代品的努力一直存在,包括我们也有在尝试。

然而,经过一段时间的实验,我们发现很多 KV Cache 相同甚至更大的 Attention 变体,最终效果都不如 MLA。这不得不让我们开始反思:MLA 的出色表现背后的关键原因究竟是什么? 

接下来,本文将详细介绍笔者围绕这一问题的思考过程以及相关实验结果。

观察

MLA 提出自 DeepSeek-V2 [1],本文假设读者已经熟悉 MLA,至少了解之前的文章缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA所介绍的内容,因此 MLA 自身的细节将不会过多展开。 

MLA 的主要特点如下:

1. MLA 在训练阶段是一个 qk_head_dims=(128+64)、v_head_dims=128 的 MHA;

2. MLA 在解码阶段是一个 qk_head_dims=(512+64)、v_head_dims=512、KV-Shared 的 MQA; 

3. MLA 的 [qc, qr]、[kc, kr] 拼接,可以理解为一种Partial RoPE

猜测

MHA、GQA 常用的 head_dims 是 128,而对于 MLA 来说,不管是从训练看的 128+64,还是从推理看的 512+64,都要大于 128,再结合《突破瓶颈,打造更强大的 Transformer》[2] 的经验,我们有: 

猜测 1:增大 head_dims 是 MLA 好的关键之一。 

另外,KV-Shared 这个特性,可以在同等 KV Cache 大小下,增大 GQA 的 head_dims 或者 num_groups,所以有: 

猜测 2:KV-Shared 是 MLA 好的关键之一。 

最后,此前有一些理论和实验显示 Partial RoPE 可能会对效果有正面帮助(参考Transformer升级之路:RoPE的底数选择原则),所以有:

猜测 3:Partial RoPE 是 MLA 好的关键之一。

实验

现在我们通过实验逐一检验以上猜测。

设置

所有实验公共部分的超参数如下:

1. 类似 LLAMA3 的 Dense 模型;

2. hidden_size=2048,num_layers=12,num_heads=16; 

3. 优化器是Muon,Attention 部分 per head 更新;

4. 训练长度为 4096,总 tokens 数为 16B,总训练步数为 16k;

5. 所有实验都是只改变 Attention,所以参数量不会严格对齐。

Part I

MLA 的 KV Cache 大小是 512+64,约等于 GQA2-128(第一个数字是 num_groups,第二个数字是 head_dims),所以对比的 baseline 为 GQA2-128 和 GQA1-256。

为了验证 Partial RoPE,我们增加了 GQA1-256-PR,具体做法是将 Q、K 的 256 dims 分成 192+64 两部分,在 64 上加 RoPE,192 不加。 

结果如下:

即:

初步验证了增大 head_dims 和 Partial RoPE 的作用。这样看来,MLA 的设计中,RoPE 和 NoPE 拼接这部分看似无奈的设计,极有可能是它效果优异的关键原因!原论文声称 MLA 甚至优于 MHA,大概率也是因为所对比的 MHA 的 head_dims 只有 128。

Part II

为了进一步验证增大 head_dims 的作用,我们另外跑了 MHA、GQA2-192、MLA-256 三个实验,MHA 是 head_dims=128 的常规 MHA,GQA2-192 是直接增大 GQA2 的 head_dims 到 192,MLA-256 是将 MLA 的 128+64 提升到 192+64,对照如下:

可以看到,MHA 总参数量更多,KV Cache 更是 7 倍于 MLA,但 Loss 才堪堪追平 MLA,这跟 DeepSeek-V2 里边的结论接近。

此外,GQA2-192 优于 GQA2-128,但不如 GQA1-256;MLA 的 head_dims 升到 (192+64) 后,相比 (128+64) 也还能进一步提升效果。这些现象都表明,增加 head_dims 远比增加 num_groups 更有效。

Part III

接下来我们验证 KV-Shared,即 K、V 共享全部或大部分 dims。这里我们主要考虑的替代品是 head_dims 不超过 256 的 GQA,并且控制 KV Cache 的总大小跟 MLA 接近,所以当 KV-Shared 时,我们可以至多可以考虑 GQA2-256。

由于 KV-Shared 跟 RoPE 不完全兼容,参考 MLA 的做法,我们将 256 分成 192+64 两部分,其中:

1. 192 部分不加 RoPE,在 K、V 间共享; 

2. 64 部分加 RoPE,只用于 K;

3. V 另外再投影 64 dims,concat 到共享的 192 dims 上去。

这样一来,K、V 的 head_dims 都是 256,KV Cache 总大小是 (192+64+64)*2=640,略大于 MLA 的 512+64=576,这个版本我们简记为“GQA2-(192+64)-S1”,其实“S1”是“Shared-1”的缩写。

Part IV

另外一种 KV-Shared 的方案是:

1. 192 部分不加 RoPE,在 K、V 间共享;

2. 64 部分加 RoPE,同样在 K、V 间共享;

3. 做 Attention,由于 V 带 RoPE,此时是绝对位置编码效果;

4. 为了保证相对位置编码,将输出分成 192+64 两部分,64 部分再加一次逆向 RoPE。

这种做法是 K、V 完全共享,KV Cache 大小是 (192+64)*2=512,略小于 MLA。这个版本我们称为“GQA2-(192+64)-S2”,“S2”是“Shared-2”的缩写,背后的原理是笔者新提出的 VO-RoPE,参考Transformer升级之路:第二类旋转位置编码 [3]。

Part V

另外,根据同样思路补了几个 GQA4 和 GQA1 的实验。所有实验结果汇总如下:

这里“GQA1-(512+64)-S3”是按照 MLA 的推理形式实现的 MQA,形式介乎 S1 与 S2 之间,它的主要特点是 head_dims 大。

结果解读:

1. KV-Shared 的 GQA 自带 Partial RoPE;

2. KV-Shared 的 GQA2-256,也能超过 MLA;

3. VO-RoPE 的引入,似乎有利于效果(S1 ≲ S2); 

4. 同等 KV Cache 下,head_dims 越大越好;

5. GQA2-(192+64)-S2 略微超过 GQA1-256-PR; 

6. GQA4-(128+64)-S2 的 KV Cache 最大,但效果不是最优,再次表明 head_dims 更关键。

关于 KV-Shared,还有两点观察:

1. 训练过程中,GQA1-256-PR 前期是明显领先 GQA2-(192+64)-S2,但后期被追平甚至略微反先,猜测 GQA1-256-PR 可能有后劲不足的嫌疑;

2. 如果没有 KV-Shared,GQA 顶多是 GQA1-256,也就是说 head_dims 顶天了 256,但有 KV-Shared 的话,GQA 可以做到 GQA1-512-S,单纯从 head_dims 看,KV-Shared 天花板更高。

Part VI

由于没有严格对齐参数量,可能读者会有“到底是增加参数量还是增加 head_dims 更本质”的疑虑,所以这里补充几个对齐参数量的实验。

这里考虑的对齐参数量的方式有三种:

1. double-heads:以“GQA2-128 vs GQA1-256”为例,将 GQA2-128 的 num_heads 翻倍,可以让 GQA2-128 的参数量跟 GQA1-256 相同; 

2. 缩减 MLP:缩小 MLP(SwiGLU)的 intermediate_size,也可以使得 GQA1-256 的参数量跟 GQA2-128 大致相同;

3. Q&O LoRA:GQA 的主要参数量来自 Query 和 Output 的投影矩阵,对这两个矩阵改用 LoRA,也可以降低 GQA1-256 的参数量。

实验结果如下:

结果主要分三块:

1. heads 翻倍相比 head_dims 翻倍,loss 稳定差 0.003 左右; 

2. 缩小 MLP 比 head_dims 减半,loss 稳定优 0.004 左右;

3. Q&O LoRA 性能损失最小,可以实现 head_dims 翻倍但参数量不增,且 loss 明显降。

结论:如果从增加参数量角度看,增大 head_dims 可能是效果增益较大的方向,配合 Q&O LoRA 可以实现参数量几乎不增,但收益仍相当。

小结

初步结论是:

1. 增大 head_dims 收益最大;

2. Partial RoPE 对 Loss 也有一定帮助;

3. KV-Shared 应该也有一定作用。

这样看来,此前我们一直在 head_dims=128 下找 MLA 的替代品,感觉是起点就先天不足了,难怪一直比不上 MLA。要想追平 MLA,head_dims 应该要 192 起步了,并辅以 Partial RoPE。至于 KV-Shared,也可能有用,但应该还需要更大规模的验证。

意义

其实这里边的意义,就看我们换掉 MLA 的决心有多强。 

假设 GQA2-(192+64)-S2 可以替代 MLA,但 MLA 也可以升到 256,目前看来 GQA2-(192+64)-S2 比不上 MLA-256 。那么换掉 MLA 的唯二好处是:

1. 结构更简单,可以方便加 QK-Norm;

2. 解码阶段的 head_dims 由 512+64 变成了 256,同时 num_groups 变为 2,可以 TP。

参考文献

[1] https://papers.cool/arxiv/2405.04434

[2] https://kexue.fm/archives/7325

[3] https://kexue.fm/archives/10862

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

·

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值