TL;DR. 本文介绍来自华为诺亚方舟实验室、清华大学和香港中文大学联合在大语言模型量化上的最新工作 FlatQuant (Fast and Learnable Affine Transformation)。FlatQuant 通过为每个线性层适配轻量的可学习的仿射变换,有效平滑 LLM 离群值,得到更加平坦的权重和激活值分布,有效降低量化损失。相比此前的量化方法 [1][2],本方法首次在 LLaMA-3-70B 上达到 W4A4 <1% 的精度损失,并可带来最高 2.3x prefill 和 1.7x decoding 加速比。论文代码已开源,欢迎大家关注和使用~
论文链接:FlatQuant: Flatness Matters for LLM Quantization
模型量化是大语言模型 (LLM) 推理加速的常用技术,可以通过将权重和激活值同时压缩到低比特来有效降低访存开销并利用峰值算力更高的 INT4/8 Tensor Core 完成矩阵运算,从而带来实际的推理加速比。然而,目前的 W4A4 量化模型相比全精度模型还存在着较大的量化损失,难以在实际应用中使用,也就难以利用峰值算力最高的 INT4 Tensor Core 加速 LLM 的实际推理部署。我们发现,量化前权重和激活值分布的平坦度 (flatness) 是影响 LLM 量化误差的关键因素,直观来看,分布越平坦,离群值就越少,量化时的精度也就越高。已有方法大多使用 pre-quantization transformations,通过在量化前对权重和激活值做等价变换得到更平坦的分布来降低量化误差,常用的变换主要有 per-channel scaling [1] 和 Hadamard 变换 [2]。然而,我们发现这些变换并不是最优的,为此我们提出 FlatQuant (Fast and Learnable Affine Transformation),为每个线性层学习一个最优的仿射变换来有效缓解权重和激活值上的离群值,从而得到平坦的权重和激活值分布,有效提升了量化精度。此外,针对推理中的在线变换,我们进行了算子融合进一步降低访存开销,使得在线变换仅带来极小的推理开销。实验表明,FlatQuant 在 W4A4 的设置下极大地减少了量化模型的精度损失,甚至在部分模型上达到了接近无损的效果 (e.g. LLaMA-3-70B),轻量的在线变换也使得 FlatQuant 能达到 2.3x 的 prefill 和 1.7x 的 decoding 加速比。我们希望 FlatQuant 能进一步推动 W4A4 LLM 的实际部署,从而更加有效地降低 LLM 的推理成本。
The Flatness for Quantization
LLM 的权重和激活值上存在较多的离群值,特别是激活值上常常存在离群值通道 (outlier channels),导致 LLM 难以量化。目前针对 LLM WA 量化的方法大多在量化前对权重和激活值做等价变换来用其他通道吸收离群值,从而得到更加平坦的分布以降低量化损失,例如 per-channel scaling [1] 对应的等价变换为 Y = ( X diag ( c ) − 1 ) ⋅ ( d i a g ( c ) W ⊤ ) {\mathbf{Y}}=(\mathbf{X}\text{diag}(\boldsymbol{c})^{-1})\cdot(\mathrm{diag}(\boldsymbol{c})\mathbf{W}^\top) Y=(Xdiag(c)−1)⋅(diag(c)W⊤),通过 scaling 将激活值上的离群值转移到权重的相同通道上,使得激活值分布更加平坦;Hadamard 变换 Y = X W ⊤ = ( X H ) ( H ⊤ W ⊤ ) \mathbf{Y}=\mathbf{X}\mathbf{W}^{\top}=(\mathbf{X}\mathbf{H})(\mathbf{H}^{\top}\mathbf{W}^{\top}) Y=XW⊤=(XH)(H⊤W⊤) 则是通过给权重和激活值同时做Hadamard变换来将离群值重新分配到权重/激活值的其他通道上;在图 1 中,我们画出了 LLM 的不同权重和激活值在变换前后的分布情况,理想情况下,我们希望能利用所有通道吸收离群值,使得变换后的分布呈现一条平坦的水平线。但如图 1 所示,我们发现已有的等价变换得到的分布仍然可能是不平坦的:per-channel scaling 中,离群值仍然被限制在了权重和激活值的相同通道上,非离群值通道得不到有效利用,因此不管是权重还是激活值,变换后的分布都非常陡峭,呈现出非常明显的离群值通道;Hadamard 变换对所有权重和激活值都施加相同的变换,而不同层的权重和激活值分布是不同的,这意味着 Hadamard 变换并不是对于每个层的最优解,例如图 1(a)(b) 中,LLaMA-3-8B 的权重和激活值经过 Hadamard 变换后仍然比较陡峭,特别是激活值上的离群值无法得到有效平滑,此外,Hadamard 变换作为一种正交变换不会改变向量的模长,而 LLM 激活值上大量的离群值会导致激活值模长显著大于权重,这导致正交变换后的激活值量化难度也会显著高于权重,无法像 per-channel scaling 一样灵活地平衡权重和激活值上的量化难度。相比之下,FlatQuant 通过给每一层针对性地学习仿射变换,不仅可以得到平坦的分布,还可以自适应地平衡权重和激活值的量化难度。
在图 2 中,我们画出了不同变换后 LLM 的量化损失平面,可以发现,per-channel scaling 和 Hadamard 变换都无法很好处理具有 massive outlier[3] 的关键词元 (pivot token),导致在首词元上具有非常大的量化误差,已有研究表明关键词元上的量化误差会比较严重地影响模型的量化精度[4]。相比之下,FlatQuant 则可以显著降低关键词元上的量化损失,并有效抑制量化误差的逐层传播,带来更加平坦的量化损失平面。
方法概述
FlatQuant 通过轻量的仿射变换平滑权重和激活值上的离群值,需要为每个线性层学习最优的仿射变换
P
∗
\mathbf P^*
P∗:
P
∗
=
arg
min
P
∥
Y
−
Q
(
X
P
)
Q
(
P
−
1
W
⊤
)
∥
F
2
\mathbf{P}^*=\arg \min _{\mathbf{P}}\left\|\mathbf{Y}-\mathcal{Q}(\mathbf{X P}) \mathcal{Q}\left(\mathbf{P}^{-1} \mathbf{W}^{\top}\right)\right\|_F^2
P∗=argPmin
Y−Q(XP)Q(P−1W⊤)
F2学到
P
\mathbf P
P 后,变换
P
−
1
W
⊤
{\mathbf{P}}^{-1} \mathbf{W}^{\top}
P−1W⊤ 可以融到权重中不会带来额外推理开销,但
X
P
\mathbf{X P}
XP 必须作为在线变换,这会使得线性层的存储和计算开销翻倍,这显然是不现实的。
Kronecker Decomposition. 为了解决上述问题,我们对 P \mathbf P P 使用 Kronecker 分解 P = P 1 ⊗ P 2 \mathbf{P}=\mathbf{P}_1 \otimes \mathbf{P}_2 P=P1⊗P2,其中 P 1 ∈ R n 1 × n 1 , P 2 ∈ R n 2 × n 2 \mathbf{P}_1 \in \mathbb{R}^{n_1 \times n_1}, \mathbf{P}_2 \in \mathbb{R}^{n_2 \times n_2} P1∈Rn1×n1,P2∈Rn2×n2,选取 n = n 1 n 2 n=n_1n_2 n=n1n2 并且 n 1 , n 2 n_1,n_2 n1,n2 尽可能接近即可,例如 n = 4096 n=4096 n=4096 时,有 n 1 = n 2 = 64 n_1=n_2=64 n1=n2=64。这样在线变换 X P \mathbf{X P} XP 可以表示为 P 1 ⊤ × 1 X ~ × 2 P 2 \mathbf{P}_1^{\top} \times_1 \tilde{\mathbf{X}} \times_2 \mathbf{P}_2 P1⊤×1X~×2P2,相当于把 X ∈ R k × n \mathbf X\in\mathbb R^{k\times n} X∈Rk×n reshape 为 X ~ ∈ R k × n 1 × n 2 \tilde {\mathbf X}\in\mathbb R^{k\times n_1\times n_2} X~∈Rk×n1×n2 然后左右分别乘上 P 1 \mathbf{P}_1 P1 和 P 2 \mathbf{P}_2 P2 两个小矩阵,相比直接使用稠密变换矩阵可以将额外的内存和计算开销分别降低至原来的 2 / n 2/n 2/n 和 2 / n 2/\sqrt n 2/n。
Per-channel Scaling. Kronecker 分解本质上还是对 P \mathbf P P 的 rank-1 近似,我们进一步使用 learnable per-channel scaling 提升 Kronecker 分解的表征能力。per-channel scaling 可以融到前序的 LN/线性层中不会带来额外推理开销。
Learnable Clipping Thresholds. 我们对变换后的权重和激活值进一步采用了 learnable clipping 来更好地消除离群值。
优化过程. 损失函数采用 Layer-wise MSE loss:
min
Θ
∥
F
l
(
X
)
−
F
^
l
(
X
;
Θ
)
∥
F
2
\min_{\boldsymbol{\Theta}}\lVert\mathcal{F}_l(\mathbf{X})-\hat{\mathcal{F}}_l(\mathbf{X};\Theta)\rVert_F^2
Θmin∥Fl(X)−F^l(X;Θ)∥F2其中,
F
l
\mathcal{F}_l
Fl 和
F
^
l
\hat{\mathcal{F}}_l
F^l 分别为 FP16 和量化后的 Transformer block,
Θ
\Theta
Θ 包括仿射变换、per-channel scaling 以及 clipping 参数,其中仿射变换初始化为随机仿射变换。整个训练过程非常轻量,7B 模型仅需单卡大约 1h 即可完成。
模型架构. 如图 3 所示,FlatQuant 在单个 Transformer 内会引入 5 种不同的在线变换,对于 LLaMA-2-7B,这些在线变换在序列长度 2K 时的 FLOPs 仅为 FP16 模型的 2.61%,对在线变换中两个小矩阵乘以及量化操作的算子融合还可以帮助进一步降低 FlatQuant 的额外推理开销。另外注意到,在 QuaRot[2] 和 SpinQuant[5] 中,为了降低在线推理开销,MHA / MLP 输入处的正交变换会被融合到前序线性层里,但由于残差连接的限制,不同 Transformer block 中的 MHA / MLP 都必须共享输入处的正交变换,这不仅限制了变换的灵活性,还使得在优化变换矩阵时必须采用端到端优化,需要较大的训练开销;相比之下,FlatQuant 不仅可以对每个线性层都学得最适配的等价变换,还可以逐层优化,仅需单卡即可完成对 70B 模型的量化。
实验结果
量化设置. 实验中,我们保持了与 QuaRot[2] 相同的量化设置,权重和激活值采用 per-channel 和 per-token 对称量化,KV cache 量化采用 group-wise 非对称量化 (g128),校准集为来自 WkiText-2 数据集的 128 条样本。
量化精度
我们测试了 W4A4 下量化模型的 PPL 和 QA 任务上的精度结果,从表 1 和表 2 中可以看到,FlatQuant 在使用 RTN 作为 weight quantizer 时精度就已经能比较明显地超过 QuaRot 和 SpinQuant 使用 GPTQ 的效果,对于较大的 13B/70B 模型,QA 精度损失均在 1% 左右,更小的 7B/8B 模型的精度损失也维持在了 2% 左右;FlatQuant 对于更难量化的 LLaMA-3 模型提升尤为明显, 例如 LLaMA-3-70B 的 QA 任务上 FlatQuant 相比 SpinQuant 有超过 7% 的精度提升,同时与全精度模型的精度差距保持在 1% 以内。
端到端加速比
我们在 RTX3090 上测试了 FlatQuant 的 prefill/decoding 端到端加速比。如图 4 所示,FlatQuant 最高能带来 2.30x 的 prefill 和 1.76x 的 decoding 加速比,推理速度超过了 QuaRot,相比 INT4 也仅有极小的加速比损失。
更多实验
(1) 消融实验. 从表 3 中可以看到,在 RTN 量化的基础上加入 LT (Learnable Transformation) 就已经能极大地提升量化模型精度,进一步加入 PS (Per-channel Scaling) 和 LCT (Learnable Clipping Thresholds) 还能进一步提升模型精度。
(2) 权重量化. FlatQuant 在权重量化上也能与 SOTA 的 uniform 量化方法达到相当的精度。
(3) Train One and Get More. FlatQuant 中 W4A4 量化设置下学到的变换矩阵可以直接用在其他量化设置下,这使得我们能更加便利地在不同量化设置下使用 FlatQuant。
参考文献
[1] Xiao, Guangxuan, et al. “Smoothquant: Accurate and efficient post-training quantization for large language models.” International Conference on Machine Learning. PMLR, 2023.
[2] Ashkboos, Saleh, et al. “Quarot: Outlier-free 4-bit inference in rotated llms.” arXiv preprint arXiv:2404.00456 (2024).
[3] Sun, Mingjie, et al. “Massive Activations in Large Language Models.” arXiv preprint arXiv:2402.17762 (2024).
[4] Liu, Ruikang, et al. "IntactKV: Improving Large Language Model Quantization by Keeping Pivot Tokens Intact."arXiv preprint arXiv:2403.01241(2024).
[5] Liu, Zechun, et al. "SpinQuant–LLM quantization with learned rotations."arXiv preprint arXiv:2405.16406(2024).