Flash-Attention 1&2 论文理解

Flash-Attention 1&2 论文理解

https://github.com/Dao-AILab/flash-attention
在这里插入图片描述
Self-Attention标准实现:
在这里插入图片描述
在这里插入图片描述

论文中 softmax ( S ) \text{softmax}(\mathbf S) softmax(S)表示对矩阵 S \mathbf S S行进行操作,且需要减去行最大值。
论文思想类似矩阵乘法优化思路,即对矩阵进行分块,通过迭代法进行计算,迭代带来的问题是,迭代过程中softmax操作没法准确获取矩阵S每行的最大值(只能拿到当前最大值)。因此需要在每次迭代后对结果 O \bf O O进行修正(这是Flash-Attention 1 的想法, Flash-Attention2改为了在最后一步修正,因为当计算到每行最后一块时,行最大值就是准确的了)。具体步骤如下(假设只分成2步):
在这里插入图片描述
对比下不进行迭代计算(但依然分块):
在这里插入图片描述
主要看行最大值 m \bm m m、指数和 l \bm l l的更新过程,比如计算 l ( 2 ) {\bm l}^{(2)} l(2),在拿到新的最大值 m ( 2 ) {\bm m}^{(2)} m(2)后,先对 l ( 1 ) {\bm l}^{(1)} l(1)进行了修正(如果最大值没变 m ( 2 ) = m ( 1 ) {\bm m}^{(2)}={\bm m}^{(1)} m(2)=m(1),则系数等于1),然后加上当前块的指数和,从而得到累计到当前块的指数和。

仔细思考可以发现迭代过程中 P ~ \bf \tilde P P~的缩放系数(对应softmax函数中除以指数和的部分)是多余的,因为在计算 O ( 2 ) \bf O^{(2)} O(2)过程中,将 O ( 1 ) \bf O^{(1)} O(1)代入后,之前除以的系数后面又被乘了回来(这里怀疑公式中 d i a g ( l ( 1 ) / l ( 2 ) ) − 1 diag(\bm l^{(1)} /\bm l^{(2)})^{-1} diag(l(1)/l(2))1是不是写反了?应该为 d i a g ( l ( 2 ) / l ( 1 ) ) − 1 diag(\bm l^{(2)} /\bm l^{(1)})^{-1} diag(l(2)/l(1))1 ?)。因此在Flash-Attention 2中改为:
在这里插入图片描述
这里 P ~ ( 2 ) \bf \tilde P^{(2)} P~(2)计算应该是多余的。

采用迭代方式目的是充分利用GPU中的共享内存,即第一张图中左侧部分,SRAM速度快,但存储小,通过迭代方式可以一次只计算大矩阵的中一小块,将所有操作限制到SRAM中,从而实现加速。

最后贴一下Flash-Attention 2的算法步骤,注意最后一步的修正:
在这里插入图片描述
可以对照这两张图进行理解:
在这里插入图片描述

在这里插入图片描述

另外,Flash-Attention 2 相比 Flash-Attention 1还改了下循环变量,2中将Q作为外层循环,1中将Q放到内层循环,显然2更加合理,因为给定 Q i \bf Q_{i} Qi后,在内层循环中,可以将O完整的计算出来,而1需要将O的中间结果写回显存,在后续的外层循环中再读出来。对比1的实现如下:
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值