FlashDecoding++

论文地址:https://arxiv.org/pdf/2311.01282.pdf

背景

LLM推理中一个常用的优化技巧是 KV Cache,通过在计算时缓存 KV 的结果,来节约大部分的运算时间;所以 LLM 推理可以分为两个部分,分别是预测第一个 token 的 prefill 阶段和使用 KV Cache 的 decode 阶段

推理两阶段示意图描述

Flash Decoding 为了解决 Flash attention 在小 batch 推理时 GPU 利用率低的问题,引入了 key/value 的并行化,提出分块处理 attention output 后合并的思路。虽然这样增加了总体的计算量,但是在 GPU 利用率不高时提速明显。

本文参考 Flash Decoding 的思路,在此基础上提出了 Flash Decoding++ ,主要解决了下面三个问题:

  1. 分块处理过程中同步更新 softmax 耗时过高
  2. Decode阶段输入只有一个 token ,存在大量的 flat GEMM 操作,GPU利用率低
  3. 静态的 kernel 无法应对动态的输入与硬件配置,需要找到一个方法平衡 内存 memory-bounded 和 计算 compute-bounded

论文效果

在 llama 7B 上测试了不同 GPU 的推理速度(主要关注 A100 ),相比于 TensorRT-LLM 和其他的推理框架有较大提升。从测试结果上来看比我目前使用的 vllm 要快 20%

在这里插入图片描述

具体方法

Partial softmax with unified max value

Flash Decoding 合并 softmax 结果的操作如下:
m ( x ) = m a x ( m ( x 1 ) , m ( x 2 ) ) f ( x 1 ) = e m ( x 1 ) − m ( x ) f ( x 1 ) f ( x 2 ) = e m ( x 2 ) − m ( x ) f ( x 2 ) l ( x ) = f ( x 1 ) + f ( x 2 ) s o f t m a x ( x 1 , x 2 ) = ( f ( x 1 ) l ( x ) , f ( x 2 ) l ( x ) ) \begin{gather} m(x) = max(m(x_1),m(x_2)) \\ f(x_1) = e^{m(x_1)-m(x)}f(x_1) \\ f(x_2) = e^{m(x_2)-m(x)}f(x_2) \\ l(x) = f(x_1) + f(x_2) \\ softmax(x_1,x_2) = (\frac{f(x_1)}{l(x)},\frac{f(x_2)}{l(x)}) \\ \end{gather} m(x)=max(m(x1),m(x2))f(x1)=em(x1)m(x)f(x1)f(x2)=em(x2)m(x)f(x2)l(x)=f(x1)+f(x2)softmax(x1,x2)=(l(x)f(x1),l(x)f(x2))

Flash Decoding++为了节约同步 m(x) 的时间,选择使用统一的最大值 ϕ 来代替。softmax 公式替换为
s o f t m a x ( x i ) = e x i − ϕ ∑ j e x j − ϕ softmax(x_i) = \frac{e^{x_i-ϕ}}{\sum_j{e^{x_j-ϕ }}} softmax(xi)=jexjϕexiϕ
设置 ϕ 的目的是为了对 softmax 后的结果进行放缩,结果太大可能会溢出 float32,太小又可能出现精度问题影响效果,所以 ϕ 应满足 a < x i − ϕ < b 。 a < x_i-ϕ < b。 a<xiϕ<b在实际推理过程中我们需要根据不同的模型和想要的精度设置对应的 a 、 ϕ 、 b a、ϕ、b aϕb。以 llama-7B 为例,作者发现 99.99%的 x i x_i xi位于[-16.8,6.5] ,ϕ 就可以在这个区间内进行选择。

如果块间差距过大可能会导致不存在 ϕ 使得 a < x i − ϕ < b a < x_i-ϕ < b a<xiϕ<b 都被满足,这时该策略会选择退化成 Flash Decoding 的动态合并 softmax,确保不会出现误差

Flat GEMM Optimization with Double Buffering

假设 GEMM 中两个相乘的矩阵大小分别为 M * K 和 K * N,同时每个 GEMM Tile 会对 K * N 的矩阵进行分块,每块大小为 Bn * Bk (不足则进行填充),那么每个 GEMM Tile的计算量为 2MBn*Bk,内存访问量为 M * Bk + Bn * Bk,共有 N * K / Bn * Bk 块。算上把乘法结果写入的内存访问,整个 GEMM 过程中计算与内存的比值为
2 ∗ M ∗ B n ∗ B k ∗ N ∗ K B n ∗ B k ( M ∗ B k + B n ∗ B k ) ∗ N ∗ K B n ∗ B k + M ∗ N = 2 ∗ M ∗ K K + M ∗ K B n + M \begin{align} \frac{2 * M * B_n * B_k * \frac{N * K}{B_n * B_k} }{(M*B_k + B_n*B_k)*\frac{N * K}{B_n * B_k} + M*N} &=\frac{2*M*K}{K+\frac{M*K}{B_n}+M} \end{align} (MBk+BnBk)BnBkNK+MN2MBnBkBnBkNK=K+BnMK+M2MK
由于 GEMM 运算的并行度为 N/Bn ,因此计算与内存的比值与 Bn 成正比,并行度与 Bn 成反比,这就让计算和内存成为同时制约 GEMM 速度的两个要素。

Tensor Core在进行 GEMM 时,会将 M padding 到 64 去减少内存访问的延时,但这会浪费大量的计算资源。本文选择将这个数值调整到 8,同时为了解决内存访问的问题,增加了 Double Buffering ,将共享的内存分成两个 buffer ,一个用于当前的 GEMM 计算,另一个加载下一次 GEMM 需要的数据。
double buffering

Heuristic Dataflow with Hardware Resource Adaption

第三点其实是对第二点的补充,作者说影响 LLM 推理的因素有很多,像第二点优化 Flat GEMM的操作实测下来也不一定是最优的,还需要具体情况具体分析,但幸运的是 LLM 中 GEMM 的结构都比较类似,以 llama-7b 为例,一共只有4种情况:
在这里插入图片描述

那就好办很多了,直接每个都测一遍。于是作者测了 FaseGEMV、flat GEMM(本文)、CUTLASS 三种方法在不同 M N K 下的速度,并且在实际推理中动态的切换:
在这里插入图片描述

结论

通篇介绍的三种方法并没有对整个推理过程做大的改动,更像是针对 llama-7B 做了特殊的代码优化,并且都存在使用条件,参数的设置都需要大量实验来确定,引入了不少选择和判断。直观感觉实用性不大,这点加速效果并不难达到,针对任何一个主流的推理框架,为某个模型单独的设计 kernel 和并行化策略应该都有这样的效果。但其中工程性的手段值得借鉴,在不改变 decoding 结构的前提下,能做的优化手段确实非常有限了。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值