一、引言
论文: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
作者: Stanford University
代码: FlashAttention
特点: 该方法提出将Q、K、V拆分为若干小块,使执行注意力时不需要频繁进行读写操作,而是每个小块只进行一次读写,从而提升注意力的执行速度。
⚠️ 在学习该方法前,建议补充Attention的相关知识。
二、详情
GPU中SRAM和HBM的计算和存储能力如下图:
可见,SRAM计算能力强(17TB/s),HBM的存储容量大(40GB)。因此,GPU的运算通常在SRAM上进行,如果运算结果的内存占用太大,系统会把运算结果先写入HBM,然后从HBM读出来再在SRAM上进行下一步的运算。
于是,我们就得到原始Attention的执行过程:
其中,Q、K、V分别是Query、Key、Value矩阵,S是相似度矩阵,P是权重矩阵,O是输出矩阵。
这里没写除以 d k \sqrt{d_k} dk的操作,不过无伤大雅,因为它对运算的影响并不大。
可见,计算S、P、O时都要进行读取,计算完成后也都要进行写入。然而,运算速度领先于读写速度导致SRAM运算完了要等数据过来才能进行下一步运算,这就拖慢了整体的速度。
2.1 拆分
FlashAttention提出将Q、K、V拆分成若干小块,这样每个小块的S、P矩阵不至于太大到需要写入HBM中,这样就能只在最开始读取Q、K、V、O(之前的运算结果),在SRAM中完成所有运算后,再将新的O写入HBM。
如果没有SoftMax操作,该过程很容易实现,如下图:
分别循环Q和K、V的小块,循环结果求和就是我们所有期望的O。但是,SoftMax阻碍了它的实现,回顾原始SoftMax公式:
s
o
f
t
m
a
x
(
s
)
j
=
e
s
j
∑
k
=
1
N
e
s
k
softmax(\boldsymbol{s})_j=\frac{e^{s_j}}{\sum_{k=1}^{N}e^{s_k}}
softmax(s)j=∑k=1Neskesj
可见,它要把相似度矩阵S的每一行转为一个概率分布。但是分块策略无法一次性获得完整的S中的行,于是FlashAttention在SoftMax中引入了
m
(
s
)
m(\boldsymbol{s})
m(s),新的SoftMax公式如下:
s
o
f
t
m
a
x
(
s
)
i
=
e
s
i
−
m
(
s
)
∑
j
=
1
N
e
s
j
−
m
(
s
)
=
f
i
l
(
s
)
softmax(\boldsymbol{s})_i=\frac{e^{s_i-m(\boldsymbol{s})}}{\sum_{j=1}^{N}e^{s_j-m(\boldsymbol{s})}}=\frac{f_i}{l(\boldsymbol{s})}
softmax(s)i=∑j=1Nesj−m(s)esi−m(s)=l(s)fi
其中,最大值
m
(
s
)
=
max
i
s
i
m(\boldsymbol{s})=\max_i s_i
m(s)=maxisi,指数和
l
(
s
)
=
∑
i
f
i
l(\boldsymbol{s})=\sum_i f_i
l(s)=∑ifi。事实上,该操作不会影响SoftMax的结果,如下:
s
o
f
t
m
a
x
(
[
1
,
2
,
3
,
10
]
)
=
[
e
1
e
1
+
e
2
+
e
3
+
e
10
,
e
2
e
1
+
e
2
+
e
3
+
e
10
,
e
3
e
1
+
e
2
+
e
3
+
e
10
,
e
10
e
1
+
e
2
+
e
3
+
e
10
]
=
[
e
1
−
10
e
1
−
10
+
e
2
−
10
+
e
3
−
10
+
e
10
−
10
,
e
2
−
10
e
1
−
10
+
e
2
−
10
+
e
3
−
10
+
e
10
−
10
,
e
3
−
10
e
1
−
10
+
e
2
−
10
+
e
3
−
10
+
e
10
−
10
,
e
10
−
10
e
1
−
10
+
e
2
−
10
+
e
3
−
10
+
e
10
−
10
]
softmax([1,2,3,10])=[\frac{e^{1}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{2}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{3}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{10}}{e^{1}+e^{2}+e^{3}+e^{10}}]\\=[\frac{e^{1-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{2-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{3-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{10-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}}]
softmax([1,2,3,10])=[e1+e2+e3+e10e1,e1+e2+e3+e10e2,e1+e2+e3+e10e3,e1+e2+e3+e10e10]=[e1−10+e2−10+e3−10+e10−10e1−10,e1−10+e2−10+e3−10+e10−10e2−10,e1−10+e2−10+e3−10+e10−10e3−10,e1−10+e2−10+e3−10+e10−10e10−10]
可见,上下同乘 e 10 e^{10} e10即可还原为原公式。
此时,我们分
T
r
=
2
T_r=2
Tr=2块分别计算上述SoftMax,有:
s
o
f
t
m
a
x
(
[
1
,
2
]
)
=
[
e
1
−
m
1
e
1
−
m
1
+
e
2
−
m
1
,
e
2
−
m
1
e
1
−
m
1
+
e
2
−
m
1
]
=
[
f
1
l
1
,
f
2
l
1
]
,
m
1
=
2
s
o
f
t
m
a
x
(
[
3
,
10
]
)
=
[
e
3
−
m
2
e
3
−
m
2
+
e
10
−
m
2
,
e
10
−
m
2
e
3
−
m
2
+
e
10
−
m
2
]
=
[
f
3
l
2
,
f
4
l
2
]
,
m
2
=
10
softmax([1,2])=[\frac{e^{1-m_1}}{e^{1-m_1}+e^{2-m_1}},\frac{e^{2-m_1}}{e^{1-m_1}+e^{2-m_1}}]=[\frac{f_1}{l_1},\frac{f_{2}}{l_1}],m_1=2\\ softmax([3,10])=[\frac{e^{3-m_2}}{e^{3-m_2}+e^{10-m_2}},\frac{e^{10-m_2}}{e^{3-m_2}+e^{10-m_2}}]=[\frac{f_3}{l_2},\frac{f_4}{l_2}],m_2=10
softmax([1,2])=[e1−m1+e2−m1e1−m1,e1−m1+e2−m1e2−m1]=[l1f1,l1f2],m1=2softmax([3,10])=[e3−m2+e10−m2e3−m2,e3−m2+e10−m2e10−m2]=[l2f3,l2f4],m2=10
其中,每个小块里减去的是当前块的最大值,记为 m i m_i mi;当前块的分子,记为 p i \boldsymbol{p}_i pi(是多个 f i f_i fi组成的向量);当前块的分母指数和,记为 l i l_i li。对应地,当前块的输出 p i / l i \boldsymbol{p}_i/l_i pi/li,记为 o \boldsymbol{o} o。
在不同块的遍历计算过程中,我们可以不断更新最大值 m ( s ) m(\boldsymbol{s}) m(s)(初始为负无穷)、指数和 l ( s ) l(\boldsymbol{s}) l(s)(初始为0)。
对于
m
(
s
)
m(\boldsymbol{s})
m(s),更新公式为
m
(
s
)
n
e
w
=
max
(
m
(
s
)
,
m
i
)
m(\boldsymbol{s})^{new}=\max(m(\boldsymbol{s}),m_i)
m(s)new=max(m(s),mi)。
对于
l
(
s
)
l(\boldsymbol{s})
l(s),更新公式为
l
(
s
)
n
e
w
=
e
m
(
s
)
−
m
(
s
)
n
e
w
×
l
(
s
)
+
e
m
i
−
m
(
s
)
n
e
w
×
l
i
l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_i-m(\boldsymbol{s})^{new}}\times l_i
l(s)new=em(s)−m(s)new×l(s)+emi−m(s)new×li。
在第一块中,
- m ( s ) n e w = max ( − inf , m 1 ) = 2 m(\boldsymbol{s})^{new}=\max(-\inf,m_1)=2 m(s)new=max(−inf,m1)=2
- l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m 1 − m ( s ) n e w × l 1 = e − inf − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_1-m(\boldsymbol{s})^{new}}\times l_1=e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2}) l(s)new=em(s)−m(s)new×l(s)+em1−m(s)new×l1=e−inf−2×0+e2−2×(e1−2+e2−2)
- 令 m ( s ) ← m ( s ) n e w m(\boldsymbol{s})\leftarrow m(\boldsymbol{s})^{new} m(s)←m(s)new, l ( s ) ← l ( s ) n e w l(\boldsymbol{s})\leftarrow l(\boldsymbol{s})^{new} l(s)←l(s)new
在第二块中,
- m ( s ) n e w = max ( 2 , m 2 ) = 10 m(\boldsymbol{s})^{new}=\max(2,m_2)=10 m(s)new=max(2,m2)=10
- l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m 2 − m ( s ) n e w × l 2 l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_2-m(\boldsymbol{s})^{new}}\times l_2 l(s)new=em(s)−m(s)new×l(s)+em2−m(s)new×l2
= e 2 − 10 × ( e 1 − 2 + e 2 − 2 ) + e 10 − 10 × ( e 3 − 10 + e 10 − 10 ) = e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 =e^{2-10}\times(e^{1-2}+e^{2-2})+e^{10-10}\times(e^{3-10}+e^{10-10})=e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10} =e2−10×(e1−2+e2−2)+e10−10×(e3−10+e10−10)=e1−10+e2−10+e3−10+e10−10
可见,最后的输出结果 m ( s ) m(\boldsymbol{s}) m(s)和 l ( s ) l(\boldsymbol{s}) l(s)已经与实际 s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) softmax([1,2,3,10]) softmax([1,2,3,10])中的一致。
m ( s ) m(\boldsymbol{s}) m(s)的更新公式能使 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new始终为当前行的最大值, l ( s ) l(\boldsymbol{s}) l(s)的更新公式能使 l ( s ) n e w l(\boldsymbol{s})^{new} l(s)new的指数项始终减的是 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new。
同样地,在遍历过程中,我们也可以根据新的 m ( s ) m(\boldsymbol{s}) m(s)和 l ( s ) l(\boldsymbol{s}) l(s)计算和更新当前的 o \boldsymbol{o} o(初始为0向量)。
对于
o
\boldsymbol{o}
o,更新公式为
o
n
e
w
=
l
(
s
)
×
e
m
(
s
)
−
m
(
s
)
n
e
w
×
o
+
e
m
i
−
m
(
s
)
n
e
w
×
p
i
×
V
i
l
(
s
)
n
e
w
\boldsymbol{o}^{new}=\frac{l(\boldsymbol{s})\times e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times \boldsymbol{o}+e^{m_i-m(\boldsymbol{s})^{new}}\times \boldsymbol{p}_i\times\boldsymbol{V}_i}{l(\boldsymbol{s})^{new}}
onew=l(s)newl(s)×em(s)−m(s)new×o+emi−m(s)new×pi×Vi
其中, p i = [ f i ∗ B r , ⋯ , f ( i + 1 ) ∗ B r ] \boldsymbol{p}_i=[f_{i*Br},\cdots,f_{(i+1)*B_r}] pi=[fi∗Br,⋯,f(i+1)∗Br], V i \boldsymbol{V}_i Vi为V矩阵的第 i i i块。
我们假设 V = [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] , [ 7 , 8 ] ] \boldsymbol{V}=[[1,2],[3,4],[5,6],[7,8]] V=[[1,2],[3,4],[5,6],[7,8]],则有
在第一块中,
- m ( s ) n e w = 2 m(\boldsymbol{s})^{new}=2 m(s)new=2
- l ( s ) n e w = e − inf − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) = e 1 − 2 + e 2 − 2 l(\boldsymbol{s})^{new}=e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2})=e^{1-2}+e^{2-2} l(s)new=e−inf−2×0+e2−2×(e1−2+e2−2)=e1−2+e2−2
- o n e w = 0 × e − inf − 2 × 0 + e 2 − 2 × [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] e − inf − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) = [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] ( e 1 − 2 + e 2 − 2 ) \boldsymbol{o}^{new}=\frac{0\times e^{-\inf-2}\times 0+e^{2-2}\times [e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2})}=\frac{[e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{(e^{1-2}+e^{2-2})} onew=e−inf−2×0+e2−2×(e1−2+e2−2)0×e−inf−2×0+e2−2×[e1−2,e2−2]×[1324]=(e1−2+e2−2)[e1−2,e2−2]×[1324]
- 令 m ( s ) ← m ( s ) n e w m(\boldsymbol{s})\leftarrow m(\boldsymbol{s})^{new} m(s)←m(s)new, l ( s ) ← l ( s ) n e w l(\boldsymbol{s})\leftarrow l(\boldsymbol{s})^{new} l(s)←l(s)new, o ← o n e w \boldsymbol{o}\leftarrow \boldsymbol{o}^{new} o←onew
在第二块中,
- m ( s ) n e w = 10 m(\boldsymbol{s})^{new}=10 m(s)new=10
- l ( s ) n e w = e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 l(\boldsymbol{s})^{new}=e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10} l(s)new=e1−10+e2−10+e3−10+e10−10
- o n e w = ( e 1 − 2 + e 2 − 2 ) × e 2 − 10 × [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] ( e 1 − 2 + e 2 − 2 ) + e 10 − 10 × [ e 3 − 10 , e 10 − 10 ] × [ 5 6 7 8 ] e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 = [ e 1 − 10 , e 2 − 10 ] × [ 1 2 3 4 ] + [ e 3 − 10 , e 10 − 10 ] × [ 5 6 7 8 ] e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 \boldsymbol{o}^{new}=\frac{(e^{1-2}+e^{2-2})\times e^{2-10}\times \frac{[e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{(e^{1-2}+e^{2-2})}+e^{10-10}\times [e^{3-10},e^{10-10}]\times \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}}\\=\frac{[e^{1-10},e^{2-10}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}+[e^{3-10},e^{10-10}]\times \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}} onew=e1−10+e2−10+e3−10+e10−10(e1−2+e2−2)×e2−10×(e1−2+e2−2)[e1−2,e2−2]×[1324]+e10−10×[e3−10,e10−10]×[5768]=e1−10+e2−10+e3−10+e10−10[e1−10,e2−10]×[1324]+[e3−10,e10−10]×[5768]
可见,最后的结果已经与实际 s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) × V softmax([1,2,3,10])\times\boldsymbol{V} softmax([1,2,3,10])×V一致。
o \boldsymbol{o} o的更新公式能使各块分子指数项上减去最新的 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new,并使各块的最新的指数和合并。
致谢:
本博客仅做记录使用,无任何商业用途,参考内容如下:
Flash Attention 为什么那么快?原理讲解
Flash Attention论文解读