flash attention: Fast and Memory-Efficient Exact Attention with IO-Awareness

FlashAttention 学习笔记,不会按照文章结构走,理解为王


前言

目前transformer 相关应用非常广泛,因此分享一篇关于flash attention的文章。这里为什么先分享flash attention?
首先,之前的 attention优化都是基于计算和稀疏性,这篇文章直接从硬件构架角度思考,减少硬件的开销。
其次,本文的效果好,直接从 N 2 N^2 N2的复杂度降为N,在不改变attention 结构的情下加速显然,甚至因为能增加输入的长度,使得效果有提升,所以在这里做一篇分享,欢迎交流。
在这里插入图片描述


1、transformer 基础

网上关于transformer的讲述非常多也讲的很好,我这里不进行赘述,祭出一个公式,后面用的比较多
a t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T ( d k ) ) V attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt(d_k)})V attention(Q,K,V)=softmax(( dk)QKT)V, 其中 Q , K , V ∈ N × d Q,K,V ∈ N×d Q,K,VN×d
N:sequence length
d:head dimension

2、GPU memory 结构

FlashAttention 有个最大的优点就是基于硬件讨论。作者在文中指出,文章的出发点就是减少IO的访问(access, 论文中作者后面的复杂度都指 access) 。下图是论文作者分享的一个比较重要的memory 构架图其中标注了硬件名称和bandwith 我会详细阐明。
在这里插入图片描述
这里推荐链接:https://developer.nvidia.com/zh-cn/blog/nvidia-ampere-architecture-in-depth/
讲明了英伟达A100的构架,都是干货

2.1、Memory hierarchy

(1)GPU SRAM
sram 一直是static random access memory,它不是寄存器也不能当做寄存器(部分博主视频讲解会把它认为是寄存器)。准确的理解是片上高速缓存,不会直接参与计算,用于存放数据。由于访问sram的delay非常小,因此将其存放反复访问的数据。寄存器的耗时delay是1cycle,sram的访问delay是2~5cycle级别。
如果在DSP构架中,SRAM是TCM(Tightly-Coupled Memory)可以用于存放指令或数据。因为几乎与内核同频,因此叫做tightly-coupled。在GPU构架中SRAM也可以做L1/L2缓存/share memory。要注意的是L2cache 和 share memory 访问速度会慢一些,访问delay的数量级大概十几个cycle或者share memory会更慢一些。

(2)GPU HBM
HBM 是 high bandwidth memory,它是由dram做的。dram全名是Dynamic Random Access Memory。访问速
会更慢,访问delay大概在几十个cycle,我们常说的 DDR就是由DRAM构成,论文这里HBM也是由DRAM构成,只是NVIDIA 加入了自己的技术TSV(Through-Silicon Vias, TSV),TSV 是贯穿每个 DRAM 芯片的垂直导电通道,使得数据可以在不同的 DRAM 层之间传输。TSV 技术使得多个 DRAM 芯片可以堆叠在一起,形成一个紧凑的内存模块。
TSV 提供了高速数据传输路径,显著提高了内存带宽。同样造价更高,因此整体尺寸小于DDR。

(3)Main Memory
Main Memory就是指的我们常说的DDR或者广义一些就是内存,也是由dram构成。相对于HBM少了 TSV因造价便宜一些使得我们可以将其的尺寸定的更大一些。

综观上述三个结构,可见sram上的数据要和计算单元紧密合作所以要求访问速度快,而HBM主打带宽高方便计算大量吞吐,而cpu内存就是主打量大。所以我们不难看出来这个内存尺寸一定是类似金字塔型的。

2.2 bandwidth

论文中有个出处:
[43] Andrei Ivanov, Nikoli Dryden, Tal Ben-Nun, Shigang Li, and Torsten Hoefler. Data movement is all you need: A case study on optimizing transformers. Proceedings of Machine Learning and Systems, 3: 711–732, 2021.
这个出处详细分析了roofline对transformer的影响。roofline主要讲解的是什么时候计算被memorysize限制,什么时候被计算带宽限制,维基百科和知乎的一篇文讲的很好:
https://en.wikipedia.org/wiki/Roofline_model
https://zhuanlan.zhihu.com/p/34204282
我这里仅贴图,不赘述。
在这里插入图片描述

引用flash attention作者的结论就是 transformer 的瓶颈是:memory access.
原文:GPUs, compute speed has out-paced memory speed [61, 62, 63], and most operations in Transformers are
bottlenecked by memory accesses [43].

3、online softmax

online soft max其实是nvidia很早之前发的一篇论文:Online normalizer calculation for softmax。没有华丽的辞藻也没有高深的数学模型,就是实打实的将如何优化softmax,步步可推敲且严谨。

3.1 、safe softmax

这里有softmax原始公式:[y] = softmax([x])定义为:
y i = e x i ∑ j = 1 V e x j y_i=\frac{e^{x_i}}{\sum^{V}_{j=1}{e^{x_j}}} yi=j=1Vexjexi,其中 x , y ∈ R V x,y∈\R^V x,yRV, 即都是长度为V的实数域向量。
上述公式按照online softmax 作者描述,是2输入1输出因此,写成伪代码则是:

Algorithm 1 Naive softmax
1: d0 = 0
2: for j in range(1, V):
3:     d[j] = d[j−1] + e^x[j]
5: for i in range(1, V):
6:     y[i](e^x[i])/dV

但实际电脑上的代码则为了保护溢出值防止计算误差变大,
因为 e x e^x ex中档x>64,当我们遇到token数量为1000甚至100000时(尤其fp16时经常遇到因为计算误差导致梯度消失不收敛),那么这个过程就非常不精确。因此这里将其减去最大,将 e x e^x ex拉到0~1之间,公式为:
y i = e x i − max ⁡ k = 1 V x k ∑ j = 1 V e x j − max ⁡ k = 1 V x k y_i=\frac{e^{x_i-\max^V_{k=1}x_k}}{\sum^{V}_{j=1}{e^{x_j-\max^V_{k=1}x_k}}} yi=j=1Vexjmaxk=1Vxkeximaxk=1Vxk,其中 x , y ∈ R V x,y∈\R^V x,yRV, 即都是长度为V的实数域向量。
且该文的作者指出,tensoflow、ptyroch、caffe2等主流构架都是用该数学公式。
因此我们有伪代码:

Algorithm 2 Safe softmax
1: m0 = −∞
2: for k in range(1, V):
3:     m[k] = max(m[k−1], x[k])
4: d[0] = 0
5: for j in range(1, V):
6:     d[j] = d[j−1] * e^m[j-1] + e^(x[j]−m[V])
7: for i ← 1, V do
8:     y[i] = e^(x[i]−m[V])/d[V]

对于这个algorithm 2,计算向量中的一个元素,就要访存4次 memory。因此需要提升这个softmax。

3.2、online softmax

为什么要看online softmax, 因为FlashAttention作者提到了这个。
online softmax作者灵感来自于:
[18] B. P. Welford. Note on a method for calculating corrected sums of squares and products. Technometrics, 4(3):419–420, 1962. URL https://amstat.tandfonline.com/doi/abs/10.1080/00401706.1962.10490022.
将一个元素的4次memory访问降为3次,伪代码如下:

Algorithm 3 Safe softmax with online normalizer calculation
1: m0 = −∞
2: d[0] = 0
3: #for k in range(1, V):#被优化掉了
4: #    m[k] = max(m[k−1], x[k])#被优化掉了
5: for j in range(1, V):
6:     m[j] = max(m[j-1], x[j-1])#计算优保留,但是不用直接访存
7:     d[j] = d[j−1] * e^m[j-1] + e^(x[j]−m[V])
7: for i ← 1, V do
8:     y[i] = e^(x[i]−m[V])/d[V]

同时,作者也证明了这个过程是safe的。写到这里 online softmax其实就是通过一个递推公式(代码第7行) d [ j ] = d [ j − 1 ] ∗ e m [ j − 1 ] + e ( x [ j ] − m [ V ] ) , d [ 0 ] = 0 d[j] = d[j−1] * e^{m[j-1]} + e^{(x[j]−m[V])}, d[0]=0 d[j]=d[j1]em[j1]+e(x[j]m[V]),d[0]=0, 通过该公式将第N个结果算出而不是通过访问memory读出。

在此之上,该文作者也提出了如何并行化上述代码,其实就是SIMD向量(说的简单些就是一次将多个同类型数据送入vector register进行并行计算)并行:
在这里插入图片描述
以及针对inference过程中的topK其实也用的是softmax,因此可以有新的优化算法从上面提到的3-pass(访问3次)降为2-pass(访问2次),于是又有如下伪代码:

Algorithm 4 Online softmax and top-K
1: m0 = −∞
2: d0 = 0
3: u = {−∞, −∞, . . . , −∞}T# u ∈ R^(K+1) The 1st K elems will hold running TopK values
4: p = {1,1, . . .,1}T # p ∈ Z^(k+1) The 1st K elems will hold running TopK values' indices 
5: for j in range( 1, V):
6:      m[j] = max (m[j−1], x[j] )
7:      d[j] = d[j−1] * (e^(m[j−1]−m[j]) + e^(x[j]−m[j]))
8:      u[K+1] = x[j] # Initialize K + 1 elem with new value from input vector
9:      p[K+1] = j# Initialize K + 1 elem with new value from input vector's index
10:
11:     # Sort u in descending order, permuting p accordingly. The first K elements are 
12:     # already sorted, so we need just a single loop, inserting the last element in 
13:     # the correct position.
14:     _k = K 
15:     while k ≥ 1 and u[_k] < u[_k+1]:
16:         swap(u[_k], u[_k+1])
17:         swap(p[_k], p[_k+1])
18:         _k = _k − 1
19: for i in range( 1, K): # The algorithm stores only K values and their indices
20:     v[i] = e^(u[i]−m[V])/d[V]
21:     z[i] = p[i]

基本上追上原始softmax的效率
在这里插入图片描述

4、selective gradient checkpointing

本小节来自于论文作者的引用[10]
[10] Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.
[34] Andreas Griewank and Andrea Walther. Evaluating derivatives: principles and techniques of algorithmic differentiation. SIAM, 2008.[34]不公开,且技术点在[10]中也有说明,我这里用[10]举例说明。 这里不是attention 的一个主要特点,因此提一下。

还有,为什么要单独开这么一节,因为这些数据非常亮眼:
1 内存消耗度直接干成O n \sqrt{n} n ,部分case变成O(logn)
2 残差网络48g干成7g

4.1 memory sharing

在这里插入图片描述
这篇引用论文的作者思路就是,在反向通路中,我们分析不同数据流的依赖关系,进而得到哪些节点的数据输入输出是一样的,那么这个节点依赖的memory就是share memory。
在这里插入图片描述
在这里插入图片描述

4.2 gradient checkpointing

这篇引用论文作者也点出了技术点:gradient checkpointing。该技术其实是在反向传播中挑选用于计算梯度的关键点。这些关键点就是 check point,如果全储存所有的中间结果代价较大,选取基哥点那么可以大大减少内存负担。
那么问题来了,如何选取?在[10]的引用文章《Implementation of Checkpointing for the Reverse or Adjoint Mode of Computational Differentiation》中的 revolve过程说明了该算法。因为我们重点是将flash attention详细展开会非常多,我这里仅做说明。

5、FlashAttention

基于上面的基础,作者首先分析了transformer 的瓶颈在于访问IO的次数,也就是频繁IO增加了耗时。因此角度都从这个方面切入。

5.1、flash attention IO complexity:

按照第1小节的公式,我们有qkv在硬件上的实现方法基本就是先算qk 接着softmax,最后乘以V,访存过程都放入HBM,因此有如下复杂度分析:
(1) Q , K , V ∈ R N × d Q,K,V ∈ \R^{N×d} Q,K,VRN×d因此 Q × K T Q×K^T Q×KT的计算复杂度为O(N×N×d),
(2)因为SRAM的宽度为M,因此访存次数为 N 2 d M \frac{N^2d}{M} MN2d
(2)softmax 如果用英伟达的 online softmax 则是倍数乘以3,因为这里是线性倍数不是指数倍数,因此可以暂时忽略。
(3)softmax之后的结果再乘以V,其计算复杂度再次增加N×d,但我们会这样写:

for(d) #d for O
    # calc the S  ################ start ####
	for(N) #N for Q
		for(N) #N for trans(K)
			for(d) # dot(Q,K), d is for both Q and trans(K)
	# calc the S ################# end  #####
	P = softmax(S)
	for(N) # N for P
		for(d) #  d for P*V

在这里插入图片描述

(4) 不考虑是否能恰好对齐sram长度,以及其他情况,假设每次都可以度满sram,因此访问sram的次数需要除以M,因为从HBM一次最多读M大小的数据。
因此得到文中的flash attention 的 IO复杂度(IO complexity)为O( N 2 d 2 M \frac{N^2d^2}{M} MN2d2)。

5.2、standard attention IO complexity:

S = Q ∗ K T , P = s o f t m a x ( S ) , O = P V ∈ R N × d S= Q*K^T, P = softmax(S) ,O = PV ∈ \R^{N × d} S=QKT,P=softmax(S)O=PVRN×d在这里插入图片描述
对于standard IO 复杂度,按照作者提供的这个过程,我们有:
(1)load Q,K ,write S: N d + N d + N 2 Nd + Nd + N^2 Nd+Nd+N2
(2)read S, softmax(S), write P: read S 不增加,因为本身就在 HBM中,接着用了 softmax(s)就是 *3,即 N 2 ∗ 3 N^2*3 N23,注:线性系数不用出现在复杂度,因此这里是 N 2 N^2 N2
(3)load P and V compute PV write O: N 2 + N d N^2 + Nd N2+Nd
(4)return O:略
所以作者说standard attention IO复杂度是 Ω ( N d + N 2 ) Ω(Nd + N^2) Ω(Nd+N2),即下限。

5.3、IO complexity compare:

flash attention 的 IO 复杂度为:O( N 2 d 2 M \frac{N^2d^2}{M} MN2d2)
而standard 的 IO复杂度为: Ω ( N d + N 2 ) Ω(Nd + N^2) Ω(Nd+N2)
看起来 flash attention的更多一些,但这里需要注意:
在这里插入图片描述
文中举出来了一种典型数据尺度:Nsequence lenght :1024, d head dimension:64~128,M为100KB。
`这种数量级起步,因此这中IO减少的思路还是非常可观的。作者也贴出了数据,证明效果。论文作者特别喜欢用`wall-clock时间`,wall-clock时间就是实际运行时间不关心中间过程如何也不用芯片频率做换算,就是`实打实,我们看墙上的表开始然后结束的时间`。下图为原文截图,作者以此说明flash attention的炸裂效果。
在这里插入图片描述

5.4、flash attention的优化手段

在描述self attention 算法前,我这里需要明确一下过程:

5.4.1 tiling

tiling 就是大家经常做的分tile。这里按照行(raw)分。
论文中:
sram_size: 原文是M,很容易和向量混淆,我这里 sram_size 替代更清晰。
如果 x ∈ R B x∈\R^B xRB,那么有如下定义:

  1. m ( x ) = m a x ( x ) = m a x ( x 1 , x 2 , x 3 , . . . . ) m(x) = max(x) = max(x_1, x_2, x_3, ....) m(x)=max(x)=max(x1,x2,x3,....) 是标量。
  2. f ( x ) = [ e x 1 − m ( x ) , . . . . , e x B − m ( x ) ] f(x) = [e^{x_1-m(x)}, ...., e^{x_B - m(x)}] f(x)=[ex1m(x),....,exBm(x)] 是向量。
  3. ℓ ( x ) = ∑ i f ( x ) i = e x 1 − m ( x ) + e x 2 − m ( x ) + . . . . . . \ell(x) = \sum_i{f(x)_i} \\=e^{x_1-m(x)} + e^{x_2-m(x)} + ...... (x)=if(x)i=ex1m(x)+ex2m(x)+...... 符号i有迷惑性,根据上下文就是求和。
  4. s o f t m a x = f ( x ) ℓ ( x ) softmax=\frac{f(x)}{\ell(x) } softmax=(x)f(x)

如果有 x = [ x ( 1 ) x ( 2 ) ] x=[x^{(1)}\quad x^{(2)}] x=[x(1)x(2)], 其中 x ( 1 ) , x ( 2 ) ∈ R B x^{(1)}, x^{(2)} ∈\R^B x(1),x(2)RB, 那么上面的4条又有如下定义:

  1. m ( x ) = m a x ( [ x ( 1 ) , x ( 2 ) ] ) = m a x ( m a x ( x ( 1 ) ) , m a x ( x ( 2 ) ) ) = m a x ( m a x ( x 1 ( 1 ) , x 2 ( 1 ) , x 3 ( 1 ) , . . . . ) , m a x ( x 1 ( 2 ) , x 2 ( 2 ) , x 3 ( 2 ) , . . . . ) m(x) = max([x^{(1)}, x^{(2)}]) \\= max(max(x^{(1)}), max(x^{(2)})) \\=max(max(x_1^{(1)}, x_2^{(1)}, x_3^{(1)}, ....), max(x_1^{(2)}, x_2^{(2)}, x_3^{(2)},....) m(x)=max([x(1),x(2)])=max(max(x(1)),max(x(2)))=max(max(x1(1),x2(1),x3(1),....),max(x1(2),x2(2),x3(2),....) 是标量。
  2. f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] f(x) = [e^{m( x^{(1)} ) - m(x) } f( x^{(1)} )\quad e^{m( x^{(2)} ) - m(x) } f( x^{(2)} ) ] f(x)=[em(x(1))m(x)f(x(1))em(x(2))m(x)f(x(2))] 这里不加逗号而是空格,我理解作者其实希望两个 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2)写在一起不增加维度所以采用了个这么写法。
    其中: e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) = e m ( x ( 1 ) ) − m ( x ) [ e x 1 − m ( x ( 1 ) ) , . . . . , e x B − m ( x ( 1 ) ) ] = e m ( x ( 1 ) ) − m ( x ) e x 1 − m ( x ( 1 ) ) , . . . . . , e m ( x ( 1 ) ) − m ( x ) e x B − m ( x ( 1 ) ) = e m ( x ( 1 ) ) − m ( x ) e x 1 − m ( x ( 1 ) ) , . . . . . , e m ( x ( 1 ) ) − m ( x ) e x B − m ( x ( 1 ) ) = e x 1 − m ( x ) , . . . . . , e x B − m ( x ) e^{m( x^{(1)} ) - m(x) } f( x^{(1)} ) \\=e^{m( x^{(1)} ) - m(x) } [e^{x_1-m( x^{(1)} )}, ...., e^{x_B - m( x^{(1)} )}] \\=e^{m( x^{(1)} ) - m(x) }e^{x_1-m( x^{(1)} )}, .....,e^{m( x^{(1)} ) - m(x) }e^{x_B - m( x^{(1)} )} \\=e^{ \cancel{ m( x^{(1)} ) } - m(x) } e^{ x_1 - \cancel{m( x^{(1)} )} }, .....,e^{ \cancel{ m( x^{(1)} ) } - m(x) } e^{ x_B - \cancel{ m( x^{(1)} ) } } \\=e^{ x_1-m(x) }, .....,e^{x_B - m(x) } em(x(1))m(x)f(x(1))=em(x(1))m(x)[ex1m(x(1)),....,exBm(x(1))]=em(x(1))m(x)ex1m(x(1)),.....,em(x(1))m(x)exBm(x(1))=em(x(1)) m(x)ex1m(x(1)) ,.....,em(x(1)) m(x)exBm(x(1)) =ex1m(x),.....,exBm(x) 也是和定义相同。
  3. ℓ ( x ) = ℓ ( x ( 1 ) x ( 2 ) ) = [ e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) ] \ell(x) = \ell( x^{(1)} \quad x^{(2)} ) \\= [e^{m( x^{(1)} ) - m(x) } \ell( x^{(1)} )\quad e^{m( x^{(2)} ) - m(x) } \ell( x^{(2)} ) ] (x)=(x(1)x(2))=[em(x(1))m(x)(x(1))em(x(2))m(x)(x(2))]这里推导和2.中过程一样,不赘述。
  4. s o f t m a x = f ( x ) ℓ ( x ) softmax=\frac{f(x)}{\ell(x) } softmax=(x)f(x) 同理,不赘述。

通常我们分tile可以减少横跨地址带来的cache miss以及提高memory的利用率。但是对于flash attention最主要的目的是将输入减小。减小后可以一次放在 sram上做完一个tile的 s o f t m a x ( K i Q i T V i ) softmax(K_iQ^T_iV_i) softmax(KiQiTVi),然后再把每个tile的结果拼在一起。至于作者为什么分为4d,个人认为就是一个经验值,也可以8d或16d,就像transform里的 d k \sqrt{d_k} dk 。但是这里的4是fp32类型,因为sram的单位是byte,github上工程使用fp32实现,是4个byte后面再求得尺寸时可以得到验证。
在这里插入图片描述

5.4.2 recomputation:

这里作者首先说这里是为在反向传播时,不存储 O ( N 2 ) O(N^2) O(N2)尺寸的数据。如果没有优化需要存储S和P, S , P ∈ R N × N S,P ∈ \R^{N × N} S,PRN×N才能计算出有关KQV的梯度。但是如果存储输出O和矩阵归一化统计信息 ( m , ℓ ) (m, \ell) (m,)可以重新计算出有关KQV的 S 和 P。这种方法可以看成是selective gradient checkpointing,在第4章节说过这个方法。

5.4.3 implementation:

在这里插入图片描述
Require:
Q , K , V ∈ R N × d , S = Q ∗ K T Q,K,V ∈ \R^{N × d},S= Q*K^T Q,K,VRN×dS=QKT P = s o f t m a x ( S ) P = softmax(S) P=softmax(S) O = P V ∈ R N × d O = PV ∈ \R^{N × d} O=PVRN×d
line 1:
block size有两种:
(1) B c = c e l l i n g ( s r a m s i z e / 4 d ) B_c = celling(sram_size / 4d) Bc=celling(sramsize/4d)
(2) B r = m i n ( c e l l i n g ( s r a m s i z e / 4 d ) , d ) B_r = min(celling(sram_size / 4d), d) Br=min(celling(sramsize/4d),d)

line 2:
初始化 O = ( 0 ) N × d ∈ R N × d O=(0)_{N×d}∈\R^{N × d} O=(0)N×dRN×d ℓ = ( 0 ) N × d ∈ R N \ell=(0)_{N×d}∈\R^{N} =(0)N×dRN m = ( − ∞ ) N ∈ R N m=(-∞)_{N}∈\R^{N} m=()NRN
这里注意的是m,m维度是N。且上述内存都在HBM中。

line3:
(1) Q分为 T r 个 = c e l l i n g ( N / B r ) T_r个=celling(N/B_r) Tr=celling(N/Br)个 blocks Q 1 , Q 2 , . . . . , Q T r Q_1,Q_2,....,Q_{Tr} Q1,Q2,....,QTr 共计 B r × d B_r×d Br×d个。
(2) K和V分 T c 个 = c e l l i n g ( N / B c ) T_c个=celling(N/B_c) Tc=celling(N/Bc)个blocks K i K_i Ki V i V_i Vi都是 B c × d B_c×d Bc×d 个。
(3) 注意这里有个tricky的地方: B c B_c Bc 是sram_size/4划分,但是原文中却这样描述:
在这里插入图片描述
这里是因为sram_size是byte大小,因此按照M/4d划分。但是运行数据是fp32(4个btye)所以是 B r × d B_r×d Br×d

line4:
同理划分O, ℓ \ell 各为 B r × d B_r×d Br×d,m 分为 B r B_r Br

line5:
for Tc 循环:

line6:
K j 和 K_j和 KjV_i$从HBM上加载到sram中。

line7:
for Tr 循环:

line8:
从HBM中读取 Q i , O i , ℓ i , m i Q_i,O_i,\ell_i,m_i Qi,Oi,i,mi
line9:
计算 S i j = Q i K j T ∈ R B r × B c S_{ij}=Q_iK^T_j ∈ \R^{B_r × B_c} Sij=QiKjTRBr×Bc
line10:
(1) 计算 m ~ i j = r o w m a x ( S i j ) ∈ R B r \widetilde m_ij=rowmax(S_{ij})∈\R^{B_r} m ij=rowmax(Sij)RBr到这里,我用一张图说明过程:
在这里插入图片描述
(2) 计算 P ~ i j = e x p ( S i j − m ~ i j ) ∈ R B r × B c \widetilde P_{ij}=exp(S_{ij} - \widetilde m_{ij}) ∈ \R^{B_r × B_c} P ij=exp(Sijm ij)RBr×Bc,注意这里算出来的 m ~ i j \widetilde m_{ij} m ij是需要在行方向复制扩充才能匹配 S i j S_{ij} Sij的尺寸,做的工作就是
第3章节中说的online softmax的safe过程。同样依照online softmax,这里需要求和得到softmax的分母。
在这里插入图片描述
line11:

因为循环是不断迭代的,对于第一次,计算出来的 最大值 和 sum求和都是局部的最大值和求和,因此我们需要5.4.1 中tiling部分的2个公式:
ℓ ( x ) = ℓ ( x ( 1 ) x ( 2 ) ) = [ e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) ] \ell(x) = \ell( x^{(1)} \quad x^{(2)} ) = [e^{m( x^{(1)} ) - m(x) } \ell( x^{(1)} )\quad e^{m( x^{(2)} ) - m(x) } \ell( x^{(2)} ) ] (x)=(x(1)x(2))=[em(x(1))m(x)(x(1))em(x(2))m(x)(x(2))]
f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] f(x) = [e^{m( x^{(1)} ) - m(x) } f( x^{(1)} )\quad e^{m( x^{(2)} ) - m(x) } f( x^{(2)} ) ] f(x)=[em(x(1))m(x)f(x(1))em(x(2))m(x)f(x(2))]
随着循环不断更新,更新的数值是靠前面乘的系数。因此我们每次循环都会从新得到 m i n e w m_i^{new} minew ℓ i n e w \ell_i^{new} inew
同时也得到 P = s o f t m a x ( S ) P = softmax(S) P=softmax(S)

line12:
计算出 O = P V ∈ R N × d O = PV ∈ \R^{N × d} O=PVRN×d
** line13:**
将算出的 ℓ i \ell_i i m i m_i mi存到HBM中,因为反向传播要用。

5.6、flash attention IO complexity :

作者在写完上述算法后,又进行了理论分析。并且对反向传播也写了详细的优化算法。
我将在下一篇文章进行分享


  • 15
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值