FlashAttention原理:从原始Attention到FlashAttention

一、背景概述

以前的attention加速方法旨在减少attention的计算和内存需要,如sparse-attention、low-rank approximation等,但由于它们主要关注FLOP reduction,且倾向于忽略内存访问的开销,所以都没有达到wall-clock speedup。

FlashAttention比普通attention的HBM(GPU high bandwidth memory)访问量更少,并适用于一系列SRAM大小。
FlashAttention旨在让attention算法变得IO-aware,即仔细计算对于不同级别的memory的读写(如fast GPU on-chip SRAM和relatively slow GPU high bandwidth memory即HBM),各memory层级如下图。这是在普通的深度学习框架(如PyTorch、Tensorflow)所不天然支持的
请添加图片描述

二、具体原理

  1. 宏观目标:compute exact attention with far fewer memory accesses
    具体来说就是避免从HBM中读/写attention矩阵,这需要从两方面入手:
    • computing the softmax reduction without access to the whole input
    • backward时不存储大型的中间矩阵
  2. 做法:提出两个策略
    • 重塑attention计算:将input分成若干blocks,然后incrementally进行softmax reduction(即tilinng)
    • 存储forward中的softmax normalization factor,以便在backward中快速recompute attention on-chip(原来是从HBM中读取中间矩阵)

一、attention是什么

attention的本质:带权特征筛选。注意力机制描述了(序列)元素的加权平均值,其权重是根据输入的query和元素的键值进行动态计算的。

先有seq2seq模型,seq2seq其实就是输入序列通过一个RNN,得到最终编码后,再通过一个RNN,每次输出翻译的单词。如下图
在这里插入图片描述
2015年提出attention,此时attention就被绑定在上面这个seq2seq模型上,2016年提出self-attention,2017年提出transformer。因此attention并不能和transformer划等号,它只是transformer的最核心的零件。

输入的时候是一起全部输入的,输出的时候还是一个一个输出的

经典传统文本翻译模型seq2seq是基于循环神经网络RNN的,而transformer丢掉了循环神经网络,只用到了注意力

transformer

  1. 最基础的注意力:q、k、v
    如下图,给定一张图,什么问题都不问,人对它的本来印象就是k(如第一行的数字,表示的是本来印象),然后提出问题,问图里的动物有哪些。此时q(query)就是“动物”。q和k相乘得出来的就是每个元素的权重(第二行)。本例中右上角其实有个海鸥,但由于一开始根本没有注意到,因此即使乘了q也不会有任何权重。v就是客观上的各个元素,没有任何权重。
    在这里插入图片描述
    k和v都是由x(原本要翻译的这句话中的每一个词向量)分别经过不同矩阵相乘而得到的结果
    在这里插入图片描述
    在这里插入图片描述
    也就是输入向量x,分别乘不同的权重矩阵得到向量V和向量K,然后向量Q乘向量K转置得到矩阵 α \alpha α,最终矩阵 α \alpha α右乘向量V得到向量b
  2. 自注意力:自注意力就是上面的Q、K、V全部都来源于输入x(前面是只有k和v)。如下图,自注意力输入N个,也输出N个。BERT就是自注意力。并行计算的
    在这里插入图片描述
  3. 多头注意力:多头就是有多组Q、K、V。如上图是一组(transformer是8头),然后再来一组,就会生成第二组 b 1 , b 2 , . . . , b N b_1,b_2,...,b_N b1,b2,...,bN。最终生成的多组b,连接在一起就是结果。说白了多头就是多组参数,就是让模型复杂一些、让参数更多一些。
  4. 位置编码:如上述多头注意力计算过程可知,调换任意两个x的位置,各x对应的结果是不变的。因此我们若要考虑位置信息,需要引入位置编码
  5. transformer:
    在这里插入图片描述
    • 左边encoder,右边decoder。(假设任务是中翻英)encoder的作用就是把要翻译的中文进行充分理解,充分获取里面的信息。decoder就是把理解的内容进行翻译版本的输出

二、FlashAttention原理解读

  1. 贡献要点:Fast and Memory-Efficient Exact Attention with IO-Awareness
    • fast:加快模型训练速度
    • memory-efficient:减少显存占用
    • exact attention:和标准attention计算得到的结果完全一致,精度不变
    • IO-awareness:算法是改进IO的效率
  2. 以一个attention的标准计算过程为例进行存储方面的说明:从下图中可看到随着序列长度N的增长,中间矩阵的缓存是以 N 2 N^2 N2进行增长的。而且这些中间矩阵是必须要存的,因为反向传播时需要它们来计算梯度
    请添加图片描述
    pytorch在实际显卡上的attention实现:
    矩阵Q、K、V(N*d,N是序列长度,d是特征维度)存储在HBM中
    • 从HBM加载Q、K到SRAM
    • 计算出 S = Q K T S=QK^T S=QKT
    • 将S写到HBM
    • 将S加载到SRAM
    • 计算 P = s o f t m a x ( S ) P=softmax(S) P=softmax(S)
    • 将P写到HBM
    • 从HBM加载P和V到SRAM
    • 计算 O = P V O=PV O=PV
    • 把O写到HBM
    • 返回O
  3. 模型训练时的两个制约因素
    • Compute-bound:速度瓶颈在于运算。比如对于大的矩阵乘法需要多channel的卷积运算,这些运算需要的数据不多,但是计算很复杂
    • Memory-bound:速度瓶颈在于HBM的读取速度,即算力等待数据。这样的操作包含两类,这些操作都是需要数据很多,但是计算相对简单:
      • 按位操作:Relu、Dropout
      • 规约操作:sum、softmax
    • 对于Memory-bound的优化一般是进行fusion融合操作,通过不对中间结果缓存来减少GBM的访问。但中间结果在方向传播时是有用的,这样不保存的话只能在反向传播时重新计算
  4. memory层级:需要尽可能增加芯片内(SRAM)的访存、减少芯片外的HBM的访存
    请添加图片描述
  5. Flash Attention原理:避免attention matrix从HBM的读写
    • 通过分块计算,融合attention内的操作,不缓存中间结果到HBM,从而加快速度
    • 反向传播时,重新计算中间结果,以此来解决不缓存后梯度无法计算的问题
    • 具体实现:
      请添加图片描述
      请添加图片描述
      请添加图片描述
      如上图, K T K^T KT、V不变,分别读取Q的前、中间、后两行,计算出O的前、中间、后两行(非最终结果);然后取 K T K^T KT和V的后三列、后三行,并分别和Q的前、中、后两行进行计算,附加到O上,得到O的最终结果。

      上述去掉了softmax。Softmax是按行进行的,如下图所示,要让矩阵分块对attention多步进行融合计算得以进行的前提就是要解决softmax的分块问题:
      请添加图片描述
  6. softmax的分块计算:如下图2所示,softmax是可以分块计算的,但是就是不如上面的简单,softmax的分块计算需要额外保存几个变量(即 m ( x 1 ) , m ( x 2 ) , p ( x 1 ) , p ( x 2 ) , l ( x 1 ) , l ( x 2 ) m(x^1), m(x^2), p(x^1), p(x^2), l(x^1), l(x^2) m(x1),m(x2),p(x1),p(x2),l(x1),l(x2))、以及在分块的合并时需要增加一些计算量。但这些计算量相对于减少的IO时间都是非常划算的
    原softmax:
    请添加图片描述
    分块版softmax:
    请添加图片描述
    伪代码:
    请添加图片描述
  7. FlashAttention2的大致思想与FlashAttention1类似,只不过增加了以下一些工程上的优化:
    • 减少了非矩阵乘法计算,可利用TensorCore加速
    • 调整了内外循环,Q为外层循环、KV为内层循环。进一步减少了HBM读写
    • 如果一个block处于矩阵上三角部分(也就是被mask掉的部分),就不进行attention计算,进一步减少计算量

三、FlashAttention更详细的细节剖析

可见这篇文章,写得很好,一篇就足够了。
笔记很详细,见实验室-FlashAttention修改-flashAttention原理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值