Introduction
- 作者提出 Flash-Decoding,通过在 FlashAttention 的基础上增加 key/values sequence length 维度上的并行,有效提升了 FlashAttention 在 LLM 长序列推理场景下的 GPU 利用率
Method
- Motivation. decoding 阶段由于 query length 为 1,因此主要的访存瓶颈在于加载 KV cache 而非 prefill 阶段的加载 attetion matrix,标准 attention 实现和 FlashAttention 的访存开销均为
O
(
N
d
)
O(Nd)
O(Nd). 并且 FlashAttention 只在 bacth size、number of heads 和 query sequence length 维度上做了并行,对于 decoding 阶段,如果序列长度很长,那么通常 bacth size 也会比较小,FlashAttention 的并行度就很低,GPU 利用率也会很低
- Flash-Decoding. Flash-Decoding 在 keys/values sequance length 上也做了并行。首先将 keys/values 划分成若干 chunks,然后调用多个 thread blocks 并行计算 query 和每个 chunk 的 attetion 结果,kernel 还是采用 FlashAttention,并且还要为每个 chunk 记录 log-sum-exp of the attention values. 这样做可以把加载 KV cache 的任务分配到不同 thread blocks 上,提升 HBM 带宽利用率,从而提升 GPU 利用率. 由于不同 thread blocks 间无法直接通信,因此最后还需要将中间结果写回 HBM,并调用一个单独的 kernel 对所有 chunks 的计算结果做 reduce,也就是根据 log-sum-exp 对所有 chunks 的结果做加权和得到最终的计算结果
Experiments
- Benchmarks on CodeLlama 34B.
- Component-level micro-benchmarks.