FlashAttention介绍(1.0版本)

一、为什么需要FlashAttention

在Transformer 结构中,自注意力机制的时间和存储复杂度与序列的长度呈平方的关系,因此占用了大量的计算设备内存并消耗了大量的计算资源。如何优化自注意力机制的时空复杂度、增强计算效率是大语言模型面临的重要问题。

二、什么是FlashAttention

1. FlashAttention的灵感来源

FlashAttention:IO感知的精确注意力。通过GPU硬件的特殊设计,尽可能的避免在HBM中读取或者写入注意力矩阵。
在这里插入图片描述
以自注意力的计算为例,其计算流程为:
1)首先从全局内存(HBM)中读取矩阵Q和K,并将计算好的矩阵S再写入全局内存。
2)之后再从全局内存中获取矩阵S,计算Softmax得到矩阵P再写入全局内存。
3)之后读取矩阵P和矩阵V ,计算得到矩阵O,写入HBM。
我们可以发现在自注意力分数的计算过程中,不但产生了中间注意力矩阵S、P,并且还对HBM频繁的写入S和P。虽然HBM的带宽速度达到1.5TB/s,但GPU拥有极多的线程,平均算下来带宽速度其实并不快。因此,在使用GPU训练时,训练时间的长短一般会受到HBM的限制,读写过于频繁或读写带宽较慢,跟不上计算速度,导致训练时间延长。
由上图可以发现SRAM相对于HBM拥有更高的带宽速度。因此,FlashAttention充分使用SRAM,并设计了两个目标:
1)在不访问整个输入的情况下计算softmax函数
2)不能存储中间注意力矩阵S、P
解读·:在不访问整个输入的情况下计算softmax函数,这是因为SRAM内存有限,因此需要输入进行分块,一次只计算一小部分的结果。不能存储中间注意力矩阵S、P,是为了防止频繁读写。
FlashAttention的实现难点在于softmax的分块计算,因为计算softmax的归一化因子(分母)时需要获取到完整的输入数据。

总而言之,FlashAttention并未将S、P整体写入HBM,而是通过分块计算存储Softmax的归一化因子,然后重新计算注意力,并充分利用了SRAM。

2. FlashAttention的分块计算

以对向量[1,2,3,4]计算softmax为例,阐述如何进行分块计算(分成两块[1,2]和[3,4]进行计算)
block1为:
在这里插入图片描述
block2为:
在这里插入图片描述
合并得到完整的softmax结果
在这里插入图片描述
如上图所示,m为块的最大值,f为softmax的分子,l为softamx的分母。
需注意:此处进行softmax归一化时,块的各个元素减去了块的最大值m,这是为了提高数值稳定性。softmax存在一个问题,当xi的值非常大时,e^xi也会非常快的增长,这就可能会导致数值溢出,这种情况在注意力机制中非常常见,因为注意力分数看会很大,因此,flashAttention在计算softmax前,会从每个注意力分数中减去最大的注意力分数,称之为”减去最大值“,目的时将注意力分数的数值范围缩小,使得在计算指数时不会那么容易发生溢出。(softmax归一化使用e的幂次方的好处:1.防止出现负数 2.当值非常大时,使用指数函数可以扩大差异)
当计算出两个块的m,f,l后,即可进行合并计算出完整的softmax结果。

FlashAttention的分块计算如图所示,其实就是不断循环上述图中计算softmax过程(ps:自己在纸上写矩阵代入计算会更容易理解)
在这里插入图片描述

三、FlashAttention的优势

1)加快了计算。Flash Attention并没有减少计算量FLOPs,而是从IO感知出发,减少了HBM访问次数,重新计算虽然也导致FLOP 增加,但其运行的速度更快且使用的内存更少。
2)节省了显存。Flash Attention通过引入统计量,改变注意力机制的计算顺序,避免了实例化注意力矩阵 S,P。
3)精确注意力。不同于稀疏注意力,Flash Attention只是分块计算,而不是近似计算,Flash Attention与原生注意力的结果是完全等价的。

  • 16
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值