【多模态大模型】FlashAttention in NeurIPS 2022

一、引言

论文: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
作者: Stanford University
代码: FlashAttention
特点: 该方法提出将Q、K、V拆分为若干小块,使执行注意力时不需要频繁进行读写操作,而是每个小块只进行一次读写,从而提升注意力的执行速度。

⚠️ 在学习该方法前,建议补充Attention的相关知识。

二、详情

GPU中SRAM和HBM的计算和存储能力如下图:

可见,SRAM计算能力强(17TB/s),HBM的存储容量大(40GB)。因此,GPU的运算通常在SRAM上进行,如果运算结果的内存占用太大,系统会把运算结果先写入HBM,然后从HBM读出来再在SRAM上进行下一步的运算。

于是,我们就得到原始Attention的执行过程:

其中,Q、K、V分别是Query、Key、Value矩阵,S是相似度矩阵,P是权重矩阵,O是输出矩阵。

这里没写除以 d k \sqrt{d_k} dk 的操作,不过无伤大雅,因为它对运算的影响并不大。

可见,计算S、P、O时都要进行读取,计算完成后也都要进行写入。然而,运算速度领先于读写速度导致SRAM运算完了要等数据过来才能进行下一步运算,这就拖慢了整体的速度。

2.1 拆分

FlashAttention提出将Q、K、V拆分成若干小块,这样每个小块的S、P矩阵不至于太大到需要写入HBM中,这样就能只在最开始读取Q、K、V、O(之前的运算结果),在SRAM中完成所有运算后,再将新的O写入HBM。

如果没有SoftMax操作,该过程很容易实现,如下图:

分别循环Q和K、V的小块,循环结果求和就是我们所有期望的O。但是,SoftMax阻碍了它的实现,回顾原始SoftMax公式:
s o f t m a x ( s ) j = e s j ∑ k = 1 N e s k softmax(\boldsymbol{s})_j=\frac{e^{s_j}}{\sum_{k=1}^{N}e^{s_k}} softmax(s)j=k=1Neskesj

可见,它要把相似度矩阵S的每一行转为一个概率分布。但是分块策略无法一次性获得完整的S中的行,于是FlashAttention在SoftMax中引入了 m ( s ) m(\boldsymbol{s}) m(s),新的SoftMax公式如下:
s o f t m a x ( s ) i = e s i − m ( s ) ∑ j = 1 N e s j − m ( s ) = f i l ( s ) softmax(\boldsymbol{s})_i=\frac{e^{s_i-m(\boldsymbol{s})}}{\sum_{j=1}^{N}e^{s_j-m(\boldsymbol{s})}}=\frac{f_i}{l(\boldsymbol{s})} softmax(s)i=j=1Nesjm(s)esim(s)=l(s)fi

其中,最大值 m ( s ) = max ⁡ i s i m(\boldsymbol{s})=\max_i s_i m(s)=maxisi,指数和 l ( s ) = ∑ i f i l(\boldsymbol{s})=\sum_i f_i l(s)=ifi。事实上,该操作不会影响SoftMax的结果,如下:
s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) = [ e 1 e 1 + e 2 + e 3 + e 10 , e 2 e 1 + e 2 + e 3 + e 10 , e 3 e 1 + e 2 + e 3 + e 10 , e 10 e 1 + e 2 + e 3 + e 10 ] = [ e 1 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 2 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 3 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 10 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 ] softmax([1,2,3,10])=[\frac{e^{1}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{2}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{3}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{10}}{e^{1}+e^{2}+e^{3}+e^{10}}]\\=[\frac{e^{1-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{2-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{3-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{10-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}}] softmax([1,2,3,10])=[e1+e2+e3+e10e1,e1+e2+e3+e10e2,e1+e2+e3+e10e3,e1+e2+e3+e10e10]=[e110+e210+e310+e1010e110,e110+e210+e310+e1010e210,e110+e210+e310+e1010e310,e110+e210+e310+e1010e1010]

可见,上下同乘 e 10 e^{10} e10即可还原为原公式。

此时,我们分 T r = 2 T_r=2 Tr=2块分别计算上述SoftMax,有:
s o f t m a x ( [ 1 , 2 ] ) = [ e 1 − m 1 e 1 − m 1 + e 2 − m 1 , e 2 − m 1 e 1 − m 1 + e 2 − m 1 ] = [ f 1 l 1 , f 2 l 1 ] , m 1 = 2 s o f t m a x ( [ 3 , 10 ] ) = [ e 3 − m 2 e 3 − m 2 + e 10 − m 2 , e 10 − m 2 e 3 − m 2 + e 10 − m 2 ] = [ f 3 l 2 , f 4 l 2 ] , m 2 = 10 softmax([1,2])=[\frac{e^{1-m_1}}{e^{1-m_1}+e^{2-m_1}},\frac{e^{2-m_1}}{e^{1-m_1}+e^{2-m_1}}]=[\frac{f_1}{l_1},\frac{f_{2}}{l_1}],m_1=2\\ softmax([3,10])=[\frac{e^{3-m_2}}{e^{3-m_2}+e^{10-m_2}},\frac{e^{10-m_2}}{e^{3-m_2}+e^{10-m_2}}]=[\frac{f_3}{l_2},\frac{f_4}{l_2}],m_2=10 softmax([1,2])=[e1m1+e2m1e1m1,e1m1+e2m1e2m1]=[l1f1,l1f2],m1=2softmax([3,10])=[e3m2+e10m2e3m2,e3m2+e10m2e10m2]=[l2f3,l2f4],m2=10

其中,每个小块里减去的是当前块的最大值,记为 m i m_i mi;当前块的分子,记为 p i \boldsymbol{p}_i pi(是多个 f i f_i fi组成的向量);当前块的分母指数和,记为 l i l_i li。对应地,当前块的输出 p i / l i \boldsymbol{p}_i/l_i pi/li,记为 o \boldsymbol{o} o

在不同块的遍历计算过程中,我们可以不断更新最大值 m ( s ) m(\boldsymbol{s}) m(s)(初始为负无穷)、指数和 l ( s ) l(\boldsymbol{s}) l(s)(初始为0)。

对于 m ( s ) m(\boldsymbol{s}) m(s),更新公式为 m ( s ) n e w = max ⁡ ( m ( s ) , m i ) m(\boldsymbol{s})^{new}=\max(m(\boldsymbol{s}),m_i) m(s)new=max(m(s),mi)
对于 l ( s ) l(\boldsymbol{s}) l(s),更新公式为 l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m i − m ( s ) n e w × l i l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_i-m(\boldsymbol{s})^{new}}\times l_i l(s)new=em(s)m(s)new×l(s)+emim(s)new×li

在第一块中,

  • m ( s ) n e w = max ⁡ ( − inf ⁡ , m 1 ) = 2 m(\boldsymbol{s})^{new}=\max(-\inf,m_1)=2 m(s)new=max(inf,m1)=2
  • l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m 1 − m ( s ) n e w × l 1 = e − inf ⁡ − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_1-m(\boldsymbol{s})^{new}}\times l_1=e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2}) l(s)new=em(s)m(s)new×l(s)+em1m(s)new×l1=einf2×0+e22×(e12+e22)
  • m ( s ) ← m ( s ) n e w m(\boldsymbol{s})\leftarrow m(\boldsymbol{s})^{new} m(s)m(s)new l ( s ) ← l ( s ) n e w l(\boldsymbol{s})\leftarrow l(\boldsymbol{s})^{new} l(s)l(s)new

在第二块中,

  • m ( s ) n e w = max ⁡ ( 2 , m 2 ) = 10 m(\boldsymbol{s})^{new}=\max(2,m_2)=10 m(s)new=max(2,m2)=10
  • l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m 2 − m ( s ) n e w × l 2 l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_2-m(\boldsymbol{s})^{new}}\times l_2 l(s)new=em(s)m(s)new×l(s)+em2m(s)new×l2
    = e 2 − 10 × ( e 1 − 2 + e 2 − 2 ) + e 10 − 10 × ( e 3 − 10 + e 10 − 10 ) = e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 =e^{2-10}\times(e^{1-2}+e^{2-2})+e^{10-10}\times(e^{3-10}+e^{10-10})=e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10} =e210×(e12+e22)+e1010×(e310+e1010)=e110+e210+e310+e1010

可见,最后的输出结果 m ( s ) m(\boldsymbol{s}) m(s) l ( s ) l(\boldsymbol{s}) l(s)已经与实际 s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) softmax([1,2,3,10]) softmax([1,2,3,10])中的一致。

m ( s ) m(\boldsymbol{s}) m(s)的更新公式能使 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new始终为当前行的最大值, l ( s ) l(\boldsymbol{s}) l(s)的更新公式能使 l ( s ) n e w l(\boldsymbol{s})^{new} l(s)new的指数项始终减的是 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new

同样地,在遍历过程中,我们也可以根据新的 m ( s ) m(\boldsymbol{s}) m(s) l ( s ) l(\boldsymbol{s}) l(s)计算和更新当前的 o \boldsymbol{o} o(初始为0向量)。

对于 o \boldsymbol{o} o,更新公式为
o n e w = l ( s ) × e m ( s ) − m ( s ) n e w × o + e m i − m ( s ) n e w × p i × V i l ( s ) n e w \boldsymbol{o}^{new}=\frac{l(\boldsymbol{s})\times e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times \boldsymbol{o}+e^{m_i-m(\boldsymbol{s})^{new}}\times \boldsymbol{p}_i\times\boldsymbol{V}_i}{l(\boldsymbol{s})^{new}} onew=l(s)newl(s)×em(s)m(s)new×o+emim(s)new×pi×Vi

其中, p i = [ f i ∗ B r , ⋯   , f ( i + 1 ) ∗ B r ] \boldsymbol{p}_i=[f_{i*Br},\cdots,f_{(i+1)*B_r}] pi=[fiBr,,f(i+1)Br] V i \boldsymbol{V}_i Vi为V矩阵的第 i i i块。

我们假设 V = [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] , [ 7 , 8 ] ] \boldsymbol{V}=[[1,2],[3,4],[5,6],[7,8]] V=[[1,2],[3,4],[5,6],[7,8]],则有

在第一块中,

  • m ( s ) n e w = 2 m(\boldsymbol{s})^{new}=2 m(s)new=2
  • l ( s ) n e w = e − inf ⁡ − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) = e 1 − 2 + e 2 − 2 l(\boldsymbol{s})^{new}=e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2})=e^{1-2}+e^{2-2} l(s)new=einf2×0+e22×(e12+e22)=e12+e22
  • o n e w = 0 × e − inf ⁡ − 2 × 0 + e 2 − 2 × [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] e − inf ⁡ − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) = [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] ( e 1 − 2 + e 2 − 2 ) \boldsymbol{o}^{new}=\frac{0\times e^{-\inf-2}\times 0+e^{2-2}\times [e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2})}=\frac{[e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{(e^{1-2}+e^{2-2})} onew=einf2×0+e22×(e12+e22)0×einf2×0+e22×[e12,e22]×[1324]=(e12+e22)[e12,e22]×[1324]
  • m ( s ) ← m ( s ) n e w m(\boldsymbol{s})\leftarrow m(\boldsymbol{s})^{new} m(s)m(s)new l ( s ) ← l ( s ) n e w l(\boldsymbol{s})\leftarrow l(\boldsymbol{s})^{new} l(s)l(s)new o ← o n e w \boldsymbol{o}\leftarrow \boldsymbol{o}^{new} oonew

在第二块中,

  • m ( s ) n e w = 10 m(\boldsymbol{s})^{new}=10 m(s)new=10
  • l ( s ) n e w = e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 l(\boldsymbol{s})^{new}=e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10} l(s)new=e110+e210+e310+e1010
  • o n e w = ( e 1 − 2 + e 2 − 2 ) × e 2 − 10 × [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] ( e 1 − 2 + e 2 − 2 ) + e 10 − 10 × [ e 3 − 10 , e 10 − 10 ] × [ 5 6 7 8 ] e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 = [ e 1 − 10 , e 2 − 10 ] × [ 1 2 3 4 ] + [ e 3 − 10 , e 10 − 10 ] × [ 5 6 7 8 ] e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 \boldsymbol{o}^{new}=\frac{(e^{1-2}+e^{2-2})\times e^{2-10}\times \frac{[e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{(e^{1-2}+e^{2-2})}+e^{10-10}\times [e^{3-10},e^{10-10}]\times \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}}\\=\frac{[e^{1-10},e^{2-10}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}+[e^{3-10},e^{10-10}]\times \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}} onew=e110+e210+e310+e1010(e12+e22)×e210×(e12+e22)[e12,e22]×[1324]+e1010×[e310,e1010]×[5768]=e110+e210+e310+e1010[e110,e210]×[1324]+[e310,e1010]×[5768]

可见,最后的结果已经与实际 s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) × V softmax([1,2,3,10])\times\boldsymbol{V} softmax([1,2,3,10])×V一致。

o \boldsymbol{o} o的更新公式能使各块分子指数项上减去最新的 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new,并使各块的最新的指数和合并。

致谢:

本博客仅做记录使用,无任何商业用途,参考内容如下:
Flash Attention 为什么那么快?原理讲解
Flash Attention论文解读

  • 17
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Fulin_Gao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值