LLM 推理加速:decode 阶段的 Attention 在 GPU 上的优化

LLM推理加速:decode阶段的Attention在GPU上的优化

作者:董纪莹

随着大语言模型(Large Language Models,LLMs)在各领域的广泛应用,如何以低成本构建高吞吐、低延迟的推理服务成为了一个紧迫的问题。考虑到 LLM 在 GPU 上推理时参数量和计算量较大以致于单流执行就可以充分利用 GPU 资源,我们可以把 LLM 的推理延时分解到 kernel level,因此,进一步的,不考虑时间占比小的 kernel 计算后,LLM 的延时优化也就相应的分解成 GEMM 和 Attention 的 kernel 优化。

RTP-LLM 是阿里巴巴智能引擎团队开发的大模型推理加速引擎,作为一个高性能的大模型推理解决方案,它已被广泛应用于阿里内部。在这篇文章里,我们将基于 RTP-LLM 的实践,介绍 decode 阶段的 Attention 在 GPU 上是如何优化的。

一、背景

我们比较熟悉的 Attention 计算如下图所示,包含 Q 与 K 相乘,其结果在 mask 后做 softmax,然后与 V 相乘,得到 Attention 的结果。在 LLM 推理的 decode 阶段,由于加入 KV Cache 优化,一次迭代只需要计算新增的一个 token,因此计算可以变化为当前 step 的 Q(seq == 1)与 K Cache、V Cache 做计算。







计算过程中各 tensor 的 shape 可以表示为:、

参数的解释如下表:

在本文的分析中,我们考虑简单的 Multi Head Attention 实现,即 H == H_kv。

我们希望以一个 kernel 实现上图的计算。出于性能考虑,将前一步的 BiasAdd,Rotary Embedding 也一起融合。因此这个 kernel 接受的输入是经过 QKV GEMM 的 Q、K、V,在 kernel 中完成 BiasAdd,然后 Q 和 K 会一起做 Rotary Embedding。当前的 K 和 V 会分别与之前计算得到的 KV Cache 做拼接,扩展成(B, H, S, D)的 KV Cache。然后 Q 与 K Cache 相乘,得到的结果在 S 维计算 SoftMax,再与 V Cache 相乘,得到最后的输出。

简化的代码示例如下:

#(B, 3, H, D) -> 3 * (B, H, 1, D)Q, K, V = add(QKV_buffer, bias)#(B, H, 1, D) -> (B, H, 1, D)Q, K = rotary_embedding(Q, K)#(B, H, 1, D) -> (B, H, S, D) K, V = concat(past_KV, K, V)#(B, H, 1, D) * (B, H, S, D) -> (B, H, 1, S)res = matmul(Q, K)/ sqrt(self.head_dim)#(B, H, 1, S) -> (B, H, 1, S)res = =softmax(res, dim=-1)#(B, H, 1, S) * (B, H, S, D) -> (B, H, 1, D)out = matmul(res, V)

在整个计算过程中,BiasAdd、Rotary Embedding 相对计算量较小,对 kernel 的 latency 影响较小,因此下文省略这一部分的分析。

二、计算分析

我们以当前的 TensorRT-LLM 中 Masked Multi Head Attention(MMHA)的实现为例,分析当前的 MMHA 是怎么实现高性能。

涉及到 GPU 并行计算,我们首先需要考虑的是任务划分。对于这个场景,任务划分实际上是清晰的:B 和 H 是并行维度,在执行过程中的 Q*K 和 QK*V,都可以理解成一个 batch size = B * H 的 Batch GEMV。而 SoftMax 又是一个 Reduce 操作,因此单个 GEMV 的计算最好尽量在一个 block 内完成。因此,MMHA 比较基础的任务划分大概是:

dim3 grid(B, H, 1);dim3 block(THREAD_PER_BLOCK, 1, 1);

这里的 THREAD_PER_BLOCK 是指每个 block 用多少 threads 来完成一个 head 在 S 上的计算。通常更多的 threads 会更提高每个 SM 的 active warps 以更好的利用计算资源,增加 load 指令以提高数据 load 效率,因此我们希望 THREAD_PER_BLOCK 越大越好(最好接近 1024)。但由于 kernel 整体计算逻辑较为复杂,寄存器用量较大,threads 可能会收到寄存器总量的限制;且在寄存器总量的限制下,我们可以简单的认为每个 SM 上只有一个 active block。

基于这种划分,我们继续考虑每个 block 是如何计算。传入 kernel 的 QKV buffer 实际的 layout 是(B,3, H,D),在 TensorRT-LLM 的实现中,会先 load 当前 step 的 Q 和 K 并计算 BiasAdd 和 ROPE,并将这一步得到的 K Cache 写回 global buffer。完成这些计算后,因为数据还在寄存器中,会直接计算对应的 QK dot。由于这些计算的耗时较短,我们略过这一部分分析,直接看看 TensorRT-LLM 是怎么计算 Q * K Cache 的。

Q 乘 K Cache 的计算在 D 上累加。假设我们用 half 存 KV Cache,用 float 做乘累加,为了保证 load 效率,每个 thread 会 load 连续的 16bytes 数据,也就是 8 个 elements。对于常见的 D==128 来说,需要 16 个 threads 完成一个 head 的计算。可以认为给 block 中的 threads 进行了分组,每组 16 个 threads 负责一个 head 的计算,其中每个 threads 读 8 个 elements,并完成这 8 个 elements 对应的乘累加,然后这组 threads 间通过 warp 内的 shuffle 完成当前 head 的计算,并将计算结果存到 smem 中。组和组在 S 上展开。





接下来计算 SoftMax,由于前面的计算保证了 SoftMax 需要的输入都在当前 block 内的 smem 中,通过 Block Reduce Max 和 Block Reduce Sum 就可以完成 SoftMax 的计算。

乘 V Cache 的计算思路与上文乘 K Cache 非常类似,略有不同的是这一步计算需要在 S 上累加。依然将 threads 分组,每组 16 个 threads 负责一个 head, 每个 thread 负责 8 个 elements 的计算。由于需要在 S 上累加,因此每个 thread 需要保存当前所计 GPUsde 算的 8 个 elements 的部分累加和。最后借助 smem,将不同 threads 上的部分和累加,得到 Attention 的输出。





在计算过程中,qk dot 除了 hfma 计算外,也可以调用 hmma 来完成单个 head 的计算。但由于 kernel 的性能瓶颈在访存上,dot 用哪种计算方式对性能的影响不大;我们的测试也验证了这个结论。

上文的分析中依然省略了一些细节。具体的,比如我们现在通常用 paged KV Block Array 来存储 KV Cache,也就是 KV Cache 可以在 S 维度上不连续,以便在 S 不断增长时动态的分配 buffer。但 paged 的存储并不改变 D 维的连续,因此也不影响上文的分析。此外,每个 thread 在 load KV Cache 时会多 load 一部分存进本地的寄存器,以尽可能的将 load 数据与 dot 计算 overlap。

主流框架如 vllm,xformers 等对 MMHA 的实现和优化思路都是比较类似的,仅在细节处略有差异。TensorRT-LLM 在 mmha 外还实现了 XQA 以继续优化 decode 阶段 Attention 的计算,但由于代码未开源,本文也不做分析。

三、改进与优化

当然上文分析到的简单优化在实际应用中还是不那么够用的,最常见的就是小 B 和长 S 场景。

考虑到实际的 GPU 资源,如 A100 有 108 个 SM,且每个 SM 上只有一个 block(也就是只计算一个 head),当 B * H 恰好占满 108(或 108 的整数倍)个 SM 时,可以认为占用率是比较高的。以 7B 模型,或者 72B 模型 2TP 举例,H = 32,当 B = 3 时,占用率是 88.9%;而当 B = 4 时,就会因必须打两轮而带来占用率的下降到 59%;当 B = 1 时,占用率就会低到 30%了。这个时候如果 S 比较大,我们就会发现,大部分的 device 资源还空闲着,也不得不一起等待部分 SM 完成一个时间很长的计算。

针对这种情况,我们把 S 也分配到 grid dim 上,资源分配也就改为:

dim3 grid(B, H, S_tile);dim3 block(THREAD_PER_BLOCK, 1, 1);

在这种任务划分下,结合上文分析,假设长 seq 每个 SM 上仅有一个 active block,则 waves 可以计算为:





当 waves 越接近 ceil 值,意味着 device occupancy 会越高。在小 B 大 S 的场景下,如果在 S 切分,也就是 S_tile > 1,有利于增加 occupancy。在这种情况下,S_tile 个 block 共同完成一个 head 在 S 上的计算,每个 block 负责 S / S_tile 的计算,block 间的 reduce 通过开辟额外的 global buffer 来完成。这种模式下,新增的 global 读写会带来有额外的耗时,但因为增加了 device occupancy,因此在小 B 大 S 的场景下有明显的性能提升。这也就是 flash decoding 的思路,且在各框架均有支持。

除了性能的考虑外,超长 seq 也必须走进这种实现。由于 Q * K 的结果需要在 S 上做 reduce,也就是 smem 需要存下对应大小的中间数据,根据 kernel 实现,输入类型是 half,以 float 累加,可以估计算为 6 * S。而根据 A100 每个 SM 实际可用 smem 是 163KB 计算,最大可支持的 S 在 27K 左右。当输入大于这个值时,我们必须在 seq 做切分,以保证 kernel 的计算。

另一种需要做不同的任务划分的场景是 GQA。在 GQA 的计算下,每个 head 的 KV Cache 会对应于多个 head 的 Q,为了避免 KV Cache 的重复 load,资源分配应该改为,并基于此做计算上的调整。

dim3 grid(B, H_kv, S_tile);dim3 block(THREAD_PER_BLOCK, 1, 1);

除了优化任务划分,MMHA 的优化还可以在以下方面继续展开:

1)优化寄存器用量可能达到更高的占用率(可以在一个 SM 上 launch 多个 block 或者增大每个 block 的 threads);

2)继续调整 KV Cache 的 load 行为,让计算和数据读取进一步 overlap 以缓解 memory bound 的场景;

3)在大 B 加上 GQA,Attention 会走到 compute bound,需要调整计算模式以更好的利用 tensor core 加速计算等等。

我们将持续探索和实践,以更灵活、更具拓展性的优化策略来面对日益多样化和复杂的应用场景。优化后的 kernel 会开源在 RTP-LLM 中,欢迎大家交流共建。

参考链接

[01] TensorRT-LLM

https://github.com/NVIDIA/TensorRT-LLM

[02] vllm

https://github.com/vllm-project/vllm

[03] xformers

https://github.com/facebookresearch/xformers

[04] flash decoding

https://crfm.stanford.edu/2023/10/12/flashdecoding.html

[03] RTP-LLM

https://github.com/alibaba/rtp-llm

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值