又快又准、即插即用!清华提出8比特量化Attention,两倍加速不掉点

8191bd9d0c7ccf50c3c38a9283b255b0.gif

©作者 | 张金涛

单位 | 清华大学

大模型中,线性层的低比特量化(例如 INT8, INT4)已经逐步落地;对于注意力模块,目前几乎各个模型都还在用高精度(例如 FP16 或 FP32)的注意力运算进行训练和推理。然而,随着大型模型需要处理的序列长度不断增加,Attention(注意力运算)的时间开销逐渐成为网络优化的主要瓶颈。

为了提高注意力运算的效率,清华大学陈键飞团队提出了 8Bit 的 Attention(SageAttention)。实现了 2 倍以及 2.7 倍相比于 FlashAttention2 和 xformers 的即插即用的推理加速,且在视频、图像、文本生成等大模型上均没有端到端的精度损失

5b2176da14485b9f9bec639a83f42eb4.png

论文标题:

SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration

文章链接:

https://arxiv.org/abs/2410.02367

代码链接:

https://github.com/thu-ml/SageAttention

1376b3c1af19104c26fb59c57038129f.png

即插即用举例

SageAttention 可以一行代码轻松替换掉 torch 中当前最优的 Attention 接口(scaled_dot_product_attention),实现即插即用的推理加速。

68d0fc0d90bf328a0e36ca3676338bb7.png

具体来说,SageAttention 的使用非常方便,使用 pip install sageattention 后,只需要在模型的推理脚本前加入以下三行代码即可:

1623ec473939a63e5c4d70a810e66742.png

效果上,以开源视频生成模型 CogvideoX 为例,使用 SageAttention 可以端到端加速 35%,且生成的视频无损:

▲ 全精度 Attention

▲ SageAttention

接下来,将从背景与挑战,技术方案,以及实验效果介绍 SageAttention。

fe2617a3df3d4ecfa7374e9f9b717467.png

背景

随着大模型需要处理的序列长度越来越长(比如 Llama3.1 支持 128K 的序列长度),Attention 的速度优化变得越来越重要。下图展示了一个标准的 Transformer 模型中各运算随着序列长度变化的时间占比:

e34d6d0cb3aedaf720dd1dc8152fa12f.png

0203967f740bc45f5a7f2d3a85a8d4b1.png

挑战

为了方便指代注意力元算中包含的矩阵,我们先回顾一下注意力的计算公式:

1ddedc93356238045366ff67bbcd9ff9.png

将神经网络中各运算的数值类型从高比特量化至低比特是一种有效提升计算和访存效率的方法。然而,研究团队发现直接将注意力运算中的 Q, K, P, V 从 FP16 量化为 INT8 或者 FP8 后将会导致在几乎所有模型和任务上都会得到极差的结果,例如,在 Unidiffuser 文生图模型中,会得到一张完全模糊的图像;在 Llama2-7B 进行四选一选择题任务上得到 25.5% 的准确率。

e42d4d242c0d1c610ec0931708b7dffd.png

经过仔细分析后,研究团队发现主要是两个原因导致了量化注意力的不准确:

1. 大多视频、图像生成模型中,矩阵 K 表现出了极强的通道维度的异常值分布,直接使用 INT8 或者 FP8 数据类型对其进行量化会导致巨大的误差。

678010811ba55fbccadbe7a5699a6152.png

2. 在所有模型中,对矩阵 P, V 进行量化不能保证一个模型中所有层的精度。下表展示了对 P, V 量化后,Llama2-7B 和 Unidiffuser 模型所有层中,最差情况的层对应的量化注意力的准确度(该准确度为量化注意力相比全精度注意力的误差),可以发现不管对 P, V 矩阵进行何种 8Bit(INT8,E4M3,E5M2)量化,总有些层的准确率非常差,导致了端到端效果的下降。

d78fd1849387d5b18be40fa750a2361f.png

5f3a4a1cdff11474b50f6f248dd112ea.png

技术方案

为了解决上述的两个关键问题,研究团队提出了对应的解决办法。

对 K 进行平滑处理。SageAttention 采用了一个简单但非常实用的方法来消除矩阵 K 的异常值:K = K – mean (K) 其中 mean (K) 是沿着通道维度求平均值。

这个简单的做法不仅不会影响注意力计算的正确性 Softmax (QK^T) = Softmax (Q (K-mean (K))^T) ;且对整个 Attention 速度的影响只有 0.2%;同时还保证了量化后的注意力运算的精度:

6945f372490cfbb668a97c5dbaafb288.png

对 Q, K 进行分块 INT8 量化。对于矩阵 Q, K,SageAttention 采用了以 FlashAttention 的分块大小为粒度的 INT8 量化。这是因为:

1)对 Q, K 矩阵进行 INT8 量化相比于进行 FP8 量化,注意力的精度更高;

2)在一些常用卡上,比如 RTX4090,INT8 矩阵乘法(INT32 为累加器)的速度是 FP8(FP32 为累加器)的两倍。

对 P, V 采用 FP16 数据类型的矩阵乘法累加器。对于矩阵 P, V,SageAttention 采用了保留 P, V 为 FP16 的类型,但进行矩阵乘法时采用 FP16 数据类型的累加器。

这是因为:

1)PV 矩阵乘法的数值范围始终在 FP16 的表示范围内,且经过大量实验验证,FP16 作为累加器的数据类型不会带来任何精度损失(见下表);

2)在一些常用卡上,比如 RTX4090,以 FP16 为累加器数据类型的矩阵乘法的速度是 FP32 作为累加器的两倍。

c94362c8a8bc28b0bf24f92273769cdf.png

SageAttention 的流程图及算法如下所示:

e1c2219b94f31a8e566d12c13f318ce9.png

8a61113af13328f6b27c06234b1b1cc4.png

8391e2b69de7f0c31ce17f2d35e96522.png

实验效果

SageAttention 实现了底层的 GPU Kernel,在算子速度以及各个模型的端到端精度上都有十分不错的表现。

具体来说,算子速度相比于 FlashAttention2 和 xformers 有 2.1 以及 2.7 倍的加速。以下 4 张图展示了在 RTX4090 上,不同的序列长度下 SageAttention 的各种 Kernel 与其他方法的速度比较。

dc4fcdc70d44891f3b372be8f82316d5.png

b48aeeab160d86a27ba946cd6981e5d4.png

以下 4 张图展示了在 RTX3090 上,不同的序列长度下 SageAttention 的各种 Kernel 与其他方法的速度比较。

97812dec0b0873dbd6dbe8119984d9ed.png

28d9610f6d0ab0fb9e60ea0b40d3d1c6.png

下表展示了在 RTX4090 上,各模型中的注意力模块中 SageAttention 相比于使用模型原始的注意力的加速比。

5f218c77aa1df4530d4de3c2233f5135.png

真实任务的精度上,下表展示了 SageAttention 在视频、图像、文本生成等大模型上均没有端到端的精度损失:

be27c7daeb7f22e3b0d8a2b952d50224.png

更多阅读

d3d098d3c1b8f476ebdca7ad75d70b58.png

05b1558308df6a7d8f9287b3785654f9.png

f868a01abe0ab476ae4f58a80b3d7b29.png

fb48eff77d69444a33fd5344cdb0d099.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

313d559d53b65d4ed6090c65bbe673fa.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

·

8f72b68717b344e44efd00430ddcb317.jpeg

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值