轻松读懂FlashAttention 下

若在阅读过程中有些知识点存在盲区,可以回到如何优雅的谈论大模型重新阅读。另外斯坦福2024人工智能报告解读为通识性读物。若对于如果构建生成级别的AI架构则可以关注AI架构设计。技术宅麻烦死磕LLM背后的基础模型

FlashAttention属于AI加速器,要读懂它需要先具备Transformer的背景知识以及注意力机制,最后才到FlashAttention。随着大模型优化技术的层出不穷,里面的kernel fusion技术将会越来越频繁的被提及,例如在Mamba中也被用于加速。因此借着FlashAttention的这个机会更加深入的了解下GPU。 

前期回顾

传统注意力机制的Q、K、V的运算都是先搬到HBM(读写速度1.5T/s,容量40G再进行运算。而FlashAttention就是想利用SRAM的高速读写(19T/s)来加速,但是SRAM的容量只有20M,在面对起步都是4k+的上下文序列,如何进行切块且采用对应的流式计算则是FlashAttention面临的最大挑战。

动画回放

读到此处,其实FlashAttention也没有那么复杂。记得以前有道编程面试题目,有100亿个乱序的浮点数(文件格式存储),如何高效的编程才能求出其平均值,最大值,最小值和排在18%位的数值。其实这里都是涉及到算法的并行化和分布式化重建与设计。

下面两段视频,先让大家直观感受有FlashAttention和没有之间的区别。

标准过程(无FlashAttention)

标准Attention Algorithm(不带FA)

加入FlashAttention机制

The Flash Attention Algorithm

算法解析

通过上节的分析,FlashAttention还是为了解决下图中QKT的softmax的效率。下图仅仅是6个Token的序列,若序列长度达到4k+以上,这个矩阵的softmax的运算效率就成为很关键的因素。

如下左图所示,需要针对每一行采用softmax归一化,而softmax的计算需要将这一行的所有元素变换累加成为分母。

FlashAttention采用下面的算法进行改写计算过程,构造了类似“分而治之”的思想来达到和softmax相同的运算结果。

比起直接的计算,这种算法需要额外的存储m(x)和l(x)。大白话的意思就是将大的矩阵切块加载到SRAM,然后计算这个分块得到m和l值。利用上一轮的m和l,再次读入新块进行计算,最后更新m和l。在明白了这个公式之后,来看看FlashAttetion的算法:

不要慌张,一行一行的往下解读便是。首先QKV的数据在运算的时候会被加载到GPU的HBM(读取速度1.5T/s),一般而言HBM的容量在40G+。在这里假设SRAM的内存容量为M。

假定块的列大小为𝐵𝑐=ceil(𝑀/4𝑑),块的行大小为𝐵𝑟=𝑚𝑖𝑛(𝑀/4𝑑,𝑑)。最小值函数min的目的是防止块大小𝐵𝑟×𝐵𝑐>𝑀/4。下面这幅图很形象,麻将桌四人凑。

为什么是ceil(M/4d)?因为每个Q、K和V都是d维向量,外加输出O也是d维向量,所以就是M/4d份数据能够被满满当当的加载至SRAM。

玩具示例:假设M=1000,d=5。在此示例中,每个块大小为1000/(4*5)=50。因此将一次加载50个q、k、v、o向量,确保减少HBM和SRAM之间的互相读/写。

序列长度为N,维度为d,要运算出的O就是最终结果为N*d。把结果矩阵O初始化为零,后面逐步累加中间结果。

还记得之前softmax的矩阵么,它是按行求softmax,所以作为中间变量的m和l都是N维的。m(x) 代表某行最大的值,l(x)代表这一行exp的累加和

这里很多的矩阵,细心的你一定会发现。其实就是将Q按照Br为系数,按行切分将Tr块,而KV也是按照行,将其分解Tc块。

根据前面的计算,结果矩阵O需要切分成𝐵𝑟×𝑑的块来存放中间结果。长度为N的l和m也要切分成𝐵𝑟个元素的块,用于存放这些行当前的指数累加值和当前最大值。

小编ps切块的示意图一目了然。例如上图,Token的维度为4096。矩阵Q被切成4行/块,KV被切分成3行/块。虽然按照矩阵乘法可以算出这Q的这个块和K的这个块的乘积QKT,但是还是无法计算softmax,因为要等其他所有的值都算出来。

第5到16其实也很好理解,就是将分块的数据以及上一轮的(Oi-1,mi-1,li-1)从HBM读入到SRAM,然后在SRAM中进行局部的QKT运算,以及按照算法同步更新m和l的数值。最后将更新的数据(Oi,mi,li)一并写回到HBM,如此循环迭代直至结束。

大白话的讲就是在on-chip(对应SRAM)的运算中存储中间状态,然后利用改造softmax算法进行流式计算,这样节省了很多SRAM到HBM的数据来来回回,进而节省时间。

空间复杂度:在HBM中分配了QKVO(Nxd)、l&m(N) ,即4*N*d + 2*N。d是一个常数并且通常比N小得多(例如通常 d={32, 64, 128}, N={1024, …, 100k}),因此得到O(N)空间。这有助于将Transformer扩展到64k序列长度,当然还需要借鉴一些其他的办法,例如ALiBi。

后话

FlashAttention是一种无需任何近似即可加速注意力并减少内存占用的新算法。许多组织和研究实验室采用FlashAttention来加速他们的训练和推理。尽管FlashAttention在发布时已经比优化基线快 2-4 倍,但它仍然有相当大的空间。FlashAttention仍然不如优化矩阵乘法 (GEMM) 运算快,仅达到理论最大 FLOPs/s的25-40%。

FlashAttention2将会使FlashAttention变得更好。FlashAttention2完全从头开始重写,使用Nvidia的CUTLASS 3.x 及其核心库CuTe的原语,其速度比之前的版本快约2倍,在A100 GPU (FP16/BF16) 上达到230 TFLOPs/s。那么下期将介绍FlashAttention2,它突破了FlashAttention的一些瓶颈,以及使用更好的并行性和工作分区进行显着的加速。

  • 27
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值