Flash Attention推理性能探究

目录

1 背景

2 MHA

2.1 计算流程

2.2 Flash Attention

2.3 Flash Attention v2

3 性能

3.1 测试条件

3.2 Prefill推理性能

3.2.1 Seq

3.2.2 Batch

3.3 Decoding推理性能

3.3.1 Seq

3.3.2 Batch

3.4 Hybrid推理性能

4 其他

4.1 GQA/MQA推理场景

4.2 混合推理场景

4.3 ALiBi推理场景


1 背景

Attention机制自Transformer发扬光大之后,在LLM(Large Language Model)中同样大放异彩。然而由于Softmax的计算限制,MHA(Multi Head Attention)的计算过程长期处于严重Memory Bound状态。Flash Attention基于Softmax的数学特性,将MHA的计算融合成一个算子,并采用计算和高速SRAM访存换取低速HBM访存的策略,缓解了Memory Bound压力,大幅提高了MHA的计算速度。

本文基于Flash Attention和Flash Attention v2的C++接口,探究两者计算流程的差异对MHA推理性能的影响。

2 MHA

2.1 计算流程

MHA中的Self Attention的计算流程如上图所示,可以分为以下三个步骤。

O = Softmax(Q * K^T) * V

Step1: S = Q * K^T
Step2: P = Softmax(S)
Step3: O = P * V

类似地,MHA中的Q、K、V和O的维度如下所示,Step1计算batch * hq个矩阵乘,每个矩阵乘的维度是(sq * d)*(d * sk),得到S,Step2经过Softmax计算得到P,Step3计算batch * hq个矩阵乘,每个矩阵乘的维度是(sq * sk)*(sk * d),得到O。

  • Q:total_q * hq * dim

  • K:total_k * hk * dim

  • V:total_k * hk * dim

  • O:total_q * hq * dim

MHA的CPU实现代码如下,为保证精度,中间计算结果全部使用float,源码在flash_attention_inference

void mha_cpu(Tensor<half> *Q, Tensor<half> *K, Tensor<half> *V, Tensor<half> *O, Tensor<int> *cu_seq_q,
             Tensor<int> *cu_seq_k, size_t max_seq_k, bool is_causal, bool is_alibi) {
    size_t total_q = Q->getShape()[0];
    size_t head_q = Q->getShape()[1];
    size_t dim = Q->getShape()[2];
    size_t head_k = K->getShape()[1];
    size_t batch = cu_seq_q->getShape()[0] - 1;

    FAI_CHECK_EQ(head_q % head_k, 0);
    const size_t head_ratio = head_q / head_k;

    half *q_ptr = Q->getHostPtr();
    half *k_ptr = K->getHostPtr();
    half *v_ptr = V->getHostPtr();
    half *o_ptr = O->getHostPtr();

    int *cu_seq_q_ptr = cu_seq_q->getHostPtr();
    int *cu_seq_k_ptr = cu_seq_k->getHostPtr();

    // S = Q * K^T
    Tensor<float> *S = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor S");
    FAI_CHECK(S);
    float *s_ptr = S->getHostPtr();
    for (size_t b = 0; b < batch; ++b) {
        size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
        size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
        size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
        size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
        for (size_t h = 0; h < head_q; ++h) {
            size_t h_k = h / head_ratio;
            for (size_t sq = 0; sq < seq_q; ++sq) {
                for (size_t sk = 0; sk < seq_k; ++sk) {
                    float acc = 0.0;
                    for (size_t d = 0; d < dim; ++d) {
                        acc += __half2float(q_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d]) *
                               __half2float(k_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]);
                    }
                    s_ptr[sum_seq_q * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = acc;
                }
            }
        }
    }

    // P = Softmax(S)
    Tensor<float> *P = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor P");
    FAI_CHECK(P);
    float *p_ptr = P->getHostPtr();
    float scale = 1.0 / std::sqrt(dim);
    for (size_t b = 0; b < batch; ++b) {
        size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
        size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
        size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
        size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
        size_t row_shift = seq_k - seq_q;
        for (size_t h = 0; h < head_q; ++h) {
            float h_slope = is_alibi ? (1.0 / exp2(8.0 * (h + 1) / head_q)) : 0.0;
            for (size_t sq = 0; sq < seq_q; ++sq) {
                size_t col_limit = is_causal ? std::min(seq_k, sq + row_shift + 1) : seq_k;

                // Max(S)
                std::vector<float> tmp_s(seq_k, 0.0);
                float max_s = -std::numeric_limits<float>::max();
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    tmp_s[sk] = s_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] * scale;
                    if (is_alibi && sk < sq + row_shift) {
                        tmp_s[sk] +=
                            (h_slope * (static_cast<int>(sk) - static_cast<int>(sq) - static_cast<int>(row_shift)));
                    }
                    max_s = std::max(max_s, tmp_s[sk]);
                }

                // Sum(S)
                float sum_s = 0.0;
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    tmp_s[sk] = std::exp(tmp_s[sk] - max_s);
                    sum_s += tmp_s[sk];
                }

                // Softmax(S)
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = tmp_s[sk] / sum_s;
                }

                // Causal(S)
                if (is_causal) {
                    for (size_t sk = col_limit; sk < seq_k; ++sk) {
                        p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = 0.0;
                    }
                }
            }
        }
    }

    // O = P * V
    for (size_t b = 0; b < batch; ++b) {
        size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
        size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
        size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
        size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
        for (size_t h = 0; h < head_q; ++h) {
            size_t h_k = h / head_ratio;
            for (size_t sq = 0; sq < seq_q; ++sq) {
                for (size_t d = 0; d < dim; ++d) {
                    float acc = 0.0;
                    for (size_t sk = 0; sk < seq_k; ++sk) {
                        acc += p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] *
                               __half2float(v_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]);
                    }
                    o_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d] = __float2half(acc);
                }
            }
        }
    }

    if (S) {
        delete S;
        S = nullptr;
    }

    if (P) {
        delete P;
        P = nullptr;
    }
}

2.2 Flash Attention

本文只关注Flash Attention关于MHA的计算流程,其余细节可查看论文和源码。

Flash Attention对于MHA的计算是按batch、head和split-seq_q划分block,在计算Q * K^T时内部warp计算是按K^T矩阵的seq_k维度切分,即每个warp只能得到S矩阵某一行的部分block结果。因此在计算block的softmax时,需要先同步warp。另一方面,最后在计算P * V时,采用split-K方法,对于每个warp计算的中间结果还要进行reduce sum后才能得到O的block结果,在reduce之前仍然需要同步warp。

2.3 Flash Attention v2

Flash Attention v2对于MHA的计算也是按batch、head和split-seq_q划分block,但在计算Q * K^T时内部warp计算是按Q矩阵的seq_q维度切分,即每个warp可以得到S矩阵某一行的所有block结果。因此在计算block的softmax时,不需要同步warp。另一方面,最后在计算P * V时,每个warp也可以直接计算出O的warp结果,不需要reduce,也不需要额外同步warp。

3 性能

3.1 测试条件

代码开源在flash_attention_inference,kernel来自于flash-attention,移除了与推理无关的backward、dropout、bf16和torch依赖等代码,可以很方便地集成到LLM推理场景。此代码在flash attention的基础上还全部支持GQA(Group Query Attention)/ MQA(Multi Query Attention)推理场景、Prefill/Decoding混合推理场景和ALiBi(Attention with Linear Biases)推理场景。

  • MHA:O = Softmax(Q * K^T) * V

  • CUDA:11.8

  • GPU:RTX3090

  • Flash Attention:v1.0.9

  • Flash Attention v2:v2.1.0

  • Cutlass:v3.1.0

  • Head Num:32

  • Head Dim:128

3.2 Prefill推理性能

3.2.1 Seq Len

短序列时,两者性能相当;长序列时,Flash Attention v2性能更好,可以提升60%左右。Flash Attention v2在长序列表现优异的原因主要是减少了block数据之间的多次warp同步。

  • Batch Size:1

  • Seq Q:Seq Len

  • Seq K:Seq Len

3.2.2 Batch Size

Batch Size较小时,Flash Attention v2性能更好;Batch Size较大时,两者性能相当。

  • Batch Size:Batch Size

  • Seq Q:128

  • Seq K:128

3.3 Decoding推理性能

3.3.1 Seq Len

短序列时,两者性能相当;长序列时,Flash Attention性能更好,可以提升100%左右。Flash Attention在长序列表现优异的原因主要是在seq_k维度上的warp分工,提高了计算的并行性。

  • Batch Size:1

  • Seq Q:1

  • Seq K:Seq Len

3.3.2 Batch Size

Batch Size无论大小,Flash Attention性能更好。

  • Batch Size:Batch Size

  • Seq Q:1

  • Seq K:128

3.4 Hybrid推理性能

无论Prefill和Decoding两者的比例如何变化,Flash Attention和Flash Attention v2的性能都比较接近。

  • Batch Size:100

  • Seq Q:128(Prefill)+ 1(Decoding)

  • Seq K:128

4 其他

4.1 GQA/MQA推理场景

已全部支持GQA/MQA推理场景,代码更新在flash_attention_inference

4.2 混合推理场景

已全部支持Prefill和Decoding混合推理场景,代码更新在flash_attention_inference,性能如3.4所示。

4.3 ALiBi推理场景

已全部支持ALiBi推理场景,代码更新在flash_attention_inference

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值