目录
1 背景
Attention机制自Transformer发扬光大之后,在LLM(Large Language Model)中同样大放异彩。然而由于Softmax的计算限制,MHA(Multi Head Attention)的计算过程长期处于严重Memory Bound状态。Flash Attention基于Softmax的数学特性,将MHA的计算融合成一个算子,并采用计算和高速SRAM访存换取低速HBM访存的策略,缓解了Memory Bound压力,大幅提高了MHA的计算速度。
本文基于Flash Attention和Flash Attention v2的C++接口,探究两者计算流程的差异对MHA推理性能的影响。
2 MHA
2.1 计算流程
MHA中的Self Attention的计算流程如上图所示,可以分为以下三个步骤。
O = Softmax(Q * K^T) * V
Step1: S = Q * K^T
Step2: P = Softmax(S)
Step3: O = P * V
类似地,MHA中的Q、K、V和O的维度如下所示,Step1计算batch * hq个矩阵乘,每个矩阵乘的维度是(sq * d)*(d * sk),得到S,Step2经过Softmax计算得到P,Step3计算batch * hq个矩阵乘,每个矩阵乘的维度是(sq * sk)*(sk * d),得到O。
-
Q:total_q * hq * dim
-
K:total_k * hk * dim
-
V:total_k * hk * dim
-
O:total_q * hq * dim
MHA的CPU实现代码如下,为保证精度,中间计算结果全部使用float,源码在flash_attention_inference。
void mha_cpu(Tensor<half> *Q, Tensor<half> *K, Tensor<half> *V, Tensor<half> *O, Tensor<int> *cu_seq_q,
Tensor<int> *cu_seq_k, size_t max_seq_k, bool is_causal, bool is_alibi) {
size_t total_q = Q->getShape()[0];
size_t head_q = Q->getShape()[1];
size_t dim = Q->getShape()[2];
size_t head_k = K->getShape()[1];
size_t batch = cu_seq_q->getShape()[0] - 1;
FAI_CHECK_EQ(head_q % head_k, 0);
const size_t head_ratio = head_q / head_k;
half *q_ptr = Q->getHostPtr();
half *k_ptr = K->getHostPtr();
half *v_ptr = V->getHostPtr();
half *o_ptr = O->getHostPtr();
int *cu_seq_q_ptr = cu_seq_q->getHostPtr();
int *cu_seq_k_ptr = cu_seq_k->getHostPtr();
// S = Q * K^T
Tensor<float> *S = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor S");
FAI_CHECK(S);
float *s_ptr = S->getHostPtr();
for (size_t b = 0; b < batch; ++b) {
size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
for (size_t h = 0; h < head_q; ++h) {
size_t h_k = h / head_ratio;
for (size_t sq = 0; sq < seq_q; ++sq) {
for (size_t sk = 0; sk < seq_k; ++sk) {
float acc = 0.0;
for (size_t d = 0; d < dim; ++d) {
acc += __half2float(q_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d]) *
__half2float(k_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]);
}
s_ptr[sum_seq_q * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = acc;
}
}
}
}
// P = Softmax(S)
Tensor<float> *P = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor P");
FAI_CHECK(P);
float *p_ptr = P->getHostPtr();
float scale = 1.0 / std::sqrt(dim);
for (size_t b = 0; b < batch; ++b) {
size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
size_t row_shift = seq_k - seq_q;
for (size_t h = 0; h < head_q; ++h) {
float h_slope = is_alibi ? (1.0 / exp2(8.0 * (h + 1) / head_q)) : 0.0;
for (size_t sq = 0; sq < seq_q; ++sq) {
size_t col_limit = is_causal ? std::min(seq_k, sq + row_shift + 1) : seq_k;
// Max(S)
std::vector<float> tmp_s(seq_k, 0.0);
float max_s = -std::numeric_limits<float>::max();
for (size_t sk = 0; sk < col_limit; ++sk) {
tmp_s[sk] = s_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] * scale;
if (is_alibi && sk < sq + row_shift) {
tmp_s[sk] +=
(h_slope * (static_cast<int>(sk) - static_cast<int>(sq) - static_cast<int>(row_shift)));
}
max_s = std::max(max_s, tmp_s[sk]);
}
// Sum(S)
float sum_s = 0.0;
for (size_t sk = 0; sk < col_limit; ++sk) {
tmp_s[sk] = std::exp(tmp_s[sk] - max_s);
sum_s += tmp_s[sk];
}
// Softmax(S)
for (size_t sk = 0; sk < col_limit; ++sk) {
p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = tmp_s[sk] / sum_s;
}
// Causal(S)
if (is_causal) {
for (size_t sk = col_limit; sk < seq_k; ++sk) {
p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = 0.0;
}
}
}
}
}
// O = P * V
for (size_t b = 0; b < batch; ++b) {
size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
for (size_t h = 0; h < head_q; ++h) {
size_t h_k = h / head_ratio;
for (size_t sq = 0; sq < seq_q; ++sq) {
for (size_t d = 0; d < dim; ++d) {
float acc = 0.0;
for (size_t sk = 0; sk < seq_k; ++sk) {
acc += p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] *
__half2float(v_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]);
}
o_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d] = __float2half(acc);
}
}
}
}
if (S) {
delete S;
S = nullptr;
}
if (P) {
delete P;
P = nullptr;
}
}
2.2 Flash Attention
本文只关注Flash Attention关于MHA的计算流程,其余细节可查看论文和源码。
Flash Attention对于MHA的计算是按batch、head和split-seq_q划分block,在计算Q * K^T时内部warp计算是按K^T矩阵的seq_k维度切分,即每个warp只能得到S矩阵某一行的部分block结果。因此在计算block的softmax时,需要先同步warp。另一方面,最后在计算P * V时,采用split-K方法,对于每个warp计算的中间结果还要进行reduce sum后才能得到O的block结果,在reduce之前仍然需要同步warp。
2.3 Flash Attention v2
Flash Attention v2对于MHA的计算也是按batch、head和split-seq_q划分block,但在计算Q * K^T时内部warp计算是按Q矩阵的seq_q维度切分,即每个warp可以得到S矩阵某一行的所有block结果。因此在计算block的softmax时,不需要同步warp。另一方面,最后在计算P * V时,每个warp也可以直接计算出O的warp结果,不需要reduce,也不需要额外同步warp。
3 性能
3.1 测试条件
代码开源在flash_attention_inference,kernel来自于flash-attention,移除了与推理无关的backward、dropout、bf16和torch依赖等代码,可以很方便地集成到LLM推理场景。此代码在flash attention的基础上还全部支持GQA(Group Query Attention)/ MQA(Multi Query Attention)推理场景、Prefill/Decoding混合推理场景和ALiBi(Attention with Linear Biases)推理场景。
-
MHA:O = Softmax(Q * K^T) * V
-
CUDA:11.8
-
GPU:RTX3090
-
Flash Attention:v1.0.9
-
Flash Attention v2:v2.1.0
-
Cutlass:v3.1.0
-
Head Num:32
-
Head Dim:128
3.2 Prefill推理性能
3.2.1 Seq Len
短序列时,两者性能相当;长序列时,Flash Attention v2性能更好,可以提升60%左右。Flash Attention v2在长序列表现优异的原因主要是减少了block数据之间的多次warp同步。
-
Batch Size:1
-
Seq Q:Seq Len
-
Seq K:Seq Len
3.2.2 Batch Size
Batch Size较小时,Flash Attention v2性能更好;Batch Size较大时,两者性能相当。
-
Batch Size:Batch Size
-
Seq Q:128
-
Seq K:128
3.3 Decoding推理性能
3.3.1 Seq Len
短序列时,两者性能相当;长序列时,Flash Attention性能更好,可以提升100%左右。Flash Attention在长序列表现优异的原因主要是在seq_k维度上的warp分工,提高了计算的并行性。
-
Batch Size:1
-
Seq Q:1
-
Seq K:Seq Len
3.3.2 Batch Size
Batch Size无论大小,Flash Attention性能更好。
-
Batch Size:Batch Size
-
Seq Q:1
-
Seq K:128
3.4 Hybrid推理性能
无论Prefill和Decoding两者的比例如何变化,Flash Attention和Flash Attention v2的性能都比较接近。
-
Batch Size:100
-
Seq Q:128(Prefill)+ 1(Decoding)
-
Seq K:128
4 其他
4.1 GQA/MQA推理场景
已全部支持GQA/MQA推理场景,代码更新在flash_attention_inference。
4.2 混合推理场景
已全部支持Prefill和Decoding混合推理场景,代码更新在flash_attention_inference,性能如3.4所示。
4.3 ALiBi推理场景
已全部支持ALiBi推理场景,代码更新在flash_attention_inference。