FlashAttention原理:从原始Attention到FlashAttention

一、背景概述

以前的attention加速方法旨在减少attention的计算和内存需要,如sparse-attention、low-rank approximation等,但由于它们主要关注FLOP reduction,且倾向于忽略内存访问的开销,所以都没有达到wall-clock speedup。

FlashAttention比普通attention的HBM(GPU high bandwidth memory)访问量更少,并适用于一系列SRAM大小。
FlashAttention旨在让attention算法变得IO-aware,即仔细计算对于不同级别的memory的读写(如fast GPU on-chip SRAM和relatively slow GPU high bandwidth memory即HBM),各memory层级如下图。这是在普通的深度学习框架(如PyTorch、Tensorflow)所不天然支持的
请添加图片描述

二、具体原理

  1. 宏观目标:compute exact attention with far fewer memory accesses
    具体来说就是避免从HBM中读/写attention矩阵,这需要从两方面入手:
    • computing the softmax reduction without access to the whole input
    • backward时不存储大型的中间矩阵
  2. 做法:提出两个策略
    • 重塑attention计算:将input分成若干blocks,然后incrementally进行softmax reduction(即tilinng)
    • 存储forward中的softmax normalization factor,以便在backward中快速recompute attention on-chip(原来是从HBM中读取中间矩阵)

一、attention是什么

attention的本质:带权特征筛选。注意力机制描述了(序列)元素的加权平均值,其权重是根据输入的query和元素的键值进行动态计算的。

先有seq2seq模型,seq2seq其实就是输入序列通过一个RNN,得到最终编码后,再通过一个RNN,每次输出翻译的单词。如下图
在这里插入图片描述
2015年提出attention,此时attention就被绑定在上面这个seq2seq模型上,2016年提出self-attention,2017年提出transformer。因此attention并不能和transformer划等号,它只是transformer的最核心的零件。

输入的时候是一起全部输入的,输出的时候还是一个一个输出的

经典传统文本翻译模型seq2seq是基于循环神经网络RNN的,而transformer丢掉了循环神经网络,只用到了注意力

transformer

  1. 最基础的注意力:q、k、v
    如下图,给定一张图,什么问题都不问,人对它的本来印象就是k(如第一行的数字,表示的是本来印象),然后提出问题,问图里的动物有哪些。此时q(query)就是“动物”。q和k相乘得出来的就是每个元素的权重(第二行)。本例中右上角其实有个海鸥,但由于一开始根本没有注意到,因此即使乘了q也不会有任何权重。v就是客观上的各个元素,没有任何权重。
    在这里插入图片描述
    k和v都是由x(原本要翻译的这句话中的每一个词向量)分别经过不同矩阵相乘而得到的结果
    在这里插入图片描述
    在这里插入图片描述
    也就是输入向量x,分别乘不同的权重矩阵得到向量V和向量K,然后向量Q乘向量K转置得到矩阵 α \alpha α,最终矩阵 α \alpha α右乘向量V得到向量b
  2. 自注意力:自注意力就是上面的Q、K、V全部都来源于输入x(前面是只有k和v)。如下图,自注意力输入N个,也输出N个。BERT就是自注意力。并行计算的
    在这里插入图片描述
  3. 多头注意力:多头就是有多组Q、K、V。如上图是一组(transformer是8头),然后再来一组,就会生成第二组 b 1 , b 2 , . . . , b N b_1,b_2,...,b_N b1,b2,...,bN。最终生成的多组b,连接在一起就是结果。说白了多头就是多组参数,就是让模型复杂一些、让参数更多一些。
  4. 位置编码:如上述多头注意力计算过程可知,调换任意两个x的位置,各x对应的结果是不变的。因此我们若要考虑位置信息,需要引入位置编码
  5. transformer:
    在这里插入图片描述
    • 左边encoder,右边decoder。(假设任务是中翻英)encoder的作用就是把要翻译的中文进行充分理解,充分获取里面的信息。decoder就是把理解的内容进行翻译版本的输出

二、FlashAttention原理解读

  1. 贡献要点:Fast and Memory-Efficient Exact Attention with IO-Awareness
    • fast:加快模型训练速度
    • memory-efficient:减少显存占用
    • exact attention:和标准attention计算得到的结果完全一致,精度不变
    • IO-awareness:算法是改进IO的效率
  2. 以一个attention的标准计算过程为例进行存储方面的说明:从下图中可看到随着序列长度N的增长,中间矩阵的缓存是以 N 2 N^2 N2进行增长的。而且这些中间矩阵是必须要存的,因为反向传播时需要它们来计算梯度
    请添加图片描述
    pytorch在实际显卡上的attention实现:
    矩阵Q、K、V(N*d,N是序列长度,d是特征维度)存储在HBM中
    • 从HBM加载Q、K到SRAM
    • 计算出 S = Q K T S=QK^T S=QKT
    • 将S写到HBM
    • 将S加载到SRAM
    • 计算 P = s o f t m a x ( S ) P=softmax(S) P=softmax(S)
    • 将P写到HBM
    • 从HBM加载P和V到SRAM
    • 计算 O = P V O=PV O=PV
    • 把O写到HBM
    • 返回O
  3. 模型训练时的两个制约因素
    • Compute-bound:速度瓶颈在于运算。比如对于大的矩阵乘法需要多channel的卷积运算,这些运算需要的数据不多,但是计算很复杂
    • Memory-bound:速度瓶颈在于HBM的读取速度,即算力等待数据。这样的操作包含两类,这些操作都是需要数据很多,但是计算相对简单:
      • 按位操作:Relu、Dropout
      • 规约操作:sum、softmax
    • 对于Memory-bound的优化一般是进行fusion融合操作,通过不对中间结果缓存来减少GBM的访问。但中间结果在方向传播时是有用的,这样不保存的话只能在反向传播时重新计算
  4. memory层级:需要尽可能增加芯片内(SRAM)的访存、减少芯片外的HBM的访存
    请添加图片描述
  5. Flash Attention原理:避免attention matrix从HBM的读写
    • 通过分块计算,融合attention内的操作,不缓存中间结果到HBM,从而加快速度
    • 反向传播时,重新计算中间结果,以此来解决不缓存后梯度无法计算的问题
    • 具体实现:
      请添加图片描述
      请添加图片描述
      请添加图片描述
      如上图, K T K^T KT、V不变,分别读取Q的前、中间、后两行,计算出O的前、中间、后两行(非最终结果);然后取 K T K^T KT和V的后三列、后三行,并分别和Q的前、中、后两行进行计算,附加到O上,得到O的最终结果。

      上述去掉了softmax。Softmax是按行进行的,如下图所示,要让矩阵分块对attention多步进行融合计算得以进行的前提就是要解决softmax的分块问题:
      请添加图片描述
  6. softmax的分块计算:如下图2所示,softmax是可以分块计算的,但是就是不如上面的简单,softmax的分块计算需要额外保存几个变量(即 m ( x 1 ) , m ( x 2 ) , p ( x 1 ) , p ( x 2 ) , l ( x 1 ) , l ( x 2 ) m(x^1), m(x^2), p(x^1), p(x^2), l(x^1), l(x^2) m(x1),m(x2),p(x1),p(x2),l(x1),l(x2))、以及在分块的合并时需要增加一些计算量。但这些计算量相对于减少的IO时间都是非常划算的
    原softmax:
    请添加图片描述
    分块版softmax:
    请添加图片描述
    伪代码:
    请添加图片描述
  7. FlashAttention2的大致思想与FlashAttention1类似,只不过增加了以下一些工程上的优化:
    • 减少了非矩阵乘法计算,可利用TensorCore加速
    • 调整了内外循环,Q为外层循环、KV为内层循环。进一步减少了HBM读写
    • 如果一个block处于矩阵上三角部分(也就是被mask掉的部分),就不进行attention计算,进一步减少计算量

三、FlashAttention更详细的细节剖析

可见这篇文章,写得很好,一篇就足够了。
笔记很详细,见实验室-FlashAttention修改-flashAttention原理

<think>嗯,用户让我介绍如何写一个FlashAttention算子。首先,我需要理解FlashAttention是什么。我记得FlashAttention是一种高效计算注意力机制的算法,由斯坦福团队提出,主要优化了GPU内存的使用,特别是减少对高带宽内存(HBM)的访问,从而加快训练速度。好的,现在需要把这个过程讲清楚,但用户可能想了解具体实现的步骤,而不仅仅是概念。 用户可能有一定的深度学习背景,尤其是对Transformer模型和注意力机制有所了解。他们可能正在实现自己的模型,遇到了性能瓶颈,或者对优化感兴趣。所以需要详细说明实现的关键点,比如分块计算、循环处理、重计算等。 接下来,我需要分解实现步骤。首先,分块处理输入矩阵,将大的Q、K、V矩阵分成小块,这样可以利用GPU的SRAM,减少HBM的访问次数。然后,循环加载这些块到SRAM,计算局部注意力分数,接着处理softmax的分母部分,可能需要在线计算最大值和求和来保持数值稳定。还要处理反向传播时的重计算,避免存储中间结果,从而节省内存。 需要注意的是,用户可能对具体代码实现感兴趣,但直接写代码可能太复杂,而且不同框架(如PyTorch或CUDA)实现方式不同。所以应该重点讲解算法逻辑和关键步骤,而不是具体代码。 另外,内存管理的细节很重要,比如如何分块,块的大小如何确定,这会影响性能。同时,数值稳定性是softmax计算中的关键点,需要解释在线计算max和sum的方法,以及如何合并不同块的结果。 可能还需要比较传统注意力机制和FlashAttention的区别,突出FlashAttention的优势,比如内存效率和计算速度。但用户的问题主要是如何编写这个算子,所以重点放在实现步骤上。 最后,要确保回答结构清晰,分步骤说明,每个步骤的关键点,可能遇到的挑战,比如如何处理块间的依赖,如何高效地管理内存,以及如何实现反向传播的重计算。需要提醒用户在实际实现中需要处理多线程、内存对齐等底层优化,可能建议参考原始论文或现有开源实现,如Tri Dao的代码。</think>好的,我将分步骤为您介绍如何实现一个简化版的FlashAttention算子: --- ### 一、FlashAttention核心思想 1. **目标**:高效计算注意力矩阵 $QK^TV$ 2. **核心优化**: - 通过分块计算(tiling)减少GPU HBM访问次数 - 在线计算softmax避免存储中间矩阵 - 反向传播时重计算(Recompute)代替存储 --- ### 二、关键实现步骤 #### 步骤1:矩阵分块 ```python # 将Q, K, V分块 (示例维度) Q = [B, H, N, d] # Batch, Head, SeqLen, Dim K = [B, H, N, d] V = [B, H, N, d] # 分块大小根据SRAM容量确定 block_size = 64 # 典型值 ``` #### 步骤2:外循环处理K/V块 ```python for j in range(0, N, block_size): Kj = K[:, :, j:j+block_size, :] # 加载K的块到SRAM Vj = V[:, :, j:j+block_size, :] ``` #### 步骤3:内循环处理Q块 ```python for i in range(0, N, block_size): Qi = Q[:, :, i:i+block_size, :] # 加载Q的块到SRAM # 计算局部注意力分数 Sij = einsum('bhid,bhjd->bhij', Qi, Kj) / sqrt(d) ``` #### 步骤4:在线Softmax计算 ```python # 计算每块的最大值和求和项 m_ij = max(Sij, dim=-1) # 当前块的最大值 p_ij = exp(Sij - m_ij) # 稳定化指数 l_ij = sum(p_ij, dim=-1) # 分母部分 # 合并到全局统计量 m_new = max(m_prev, m_ij) l_new = exp(m_prev - m_new) * l_prev + exp(m_ij - m_new) * l_ij ``` #### 步骤5:增量更新输出 ```python # 更新输出矩阵 O_i = (exp(m_prev - m_new) * l_prev * O_prev + exp(m_ij - m_new) * p_ij @ Vj) / l_new ``` #### 步骤6:反向传播重计算 ```python # 前向时不存储S矩阵,反向时重新计算 def backward(ctx, grad_output): recompute_Sij() # 重新分块计算注意力分数 # 计算梯度... ``` --- ### 三、关键优化技术 1. **内存管理** - 将计算分解为多个块,确保每个块能在SRAM中完成计算 - 典型内存消耗:$O(N)$ → $O(\sqrt{N})$ 2. **数值稳定性** - 在线计算softmax分母(使用对数域合并) - 增量更新输出矩阵 3. **硬件适配** - 使用CUDA kernel实现 - 利用Tensor Cores加速矩阵乘 --- ### 四、与传统注意力对比 | 特性 | 传统Attention | FlashAttention | |---------------|------------------|-----------------| | 内存复杂度 | $O(N^2)$ | $O(N)$ | | HBM访问次数 | $Ω(N^2)$ | $O(N^2/M)$ | | 最大序列长度 | ~4k tokens | ~64k tokens | --- ### 五、实现建议 1. 参考官方CUDA实现:https://github.com/HazyResearch/flash-attention 2. 使用Triton语言简化开发: ```python @triton.jit def flash_attention_kernel(...) ``` 3. 注意内存对齐和bank conflict避免 需要更具体的代码实现细节,可以进一步讨论某个模块的实现!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值