FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,2022.5
官方源码:https://github.com/hazyresearch/flash-attention

FlashAttention:具有IO感知功能的快速高效精确注意力

 摘要

        Transformer在长序列上速度慢且内存消耗大,因为自注意的时间和内存复杂度是序列长度的二次方。近似注意力方法试图通过权衡模型质量来解决这个问题,以降低计算复杂度,但通常不能实现挂钟加速。我们认为,一个缺失的原则是使注意力算法IOAWARE占读和写之间的GPU内存的水平。我们提出了FlashAttention,一个IO感知的精确注意算法,使用平铺来减少GPU高带宽内存(HBM)和GPU片上SRAM之间的内存读/写次数。我们分析了FlashAttention的IO复杂性,表明它需要比标准注意更少的HBM访问,并且对于一系列SRAM大小是最佳的。我们还将FlashAttention扩展到块稀疏注意力,产生一个近似注意力算法,比任何现有的近似注意力方法都要快。FlashAttention训练Transformers的速度比现有基线更快:在BERT-large上实现15%的端到端挂钟加速(以下长度512)相比,MLPerf 1.1的训练速度记录,GPT-2(seq.长度1 K),以及2.4倍的远程竞技场(seq.长度1 K-4K)。FlashAttention和块稀疏FlashAttention支持Transformers中更长的上下文,产生更高质量的模型(GPT-2上的复杂度提高了0.7,长文档分类提升了6.4点)和全新的功能:第一个在Path-X挑战赛中获得超越机会表现的变形金刚(以下)。长度16 K,61.4%准确度)和Path-256(seq.长度64 K,63.1%准确度)。

图1

 

       我们提出了FlashAttention,一种新的注意力算法,计算精确的注意力与少得多的内存访问。我们的主要目标是避免从HBM阅读和写入注意力矩阵。这需要(i)在不访问整个输入的情况下计算softmax约简(ii)不存储用于反向传递的大的中间注意力矩阵。我们采用两种成熟的技术来应对这些挑战。(i)我们重新构造注意力计算,将输入拆分为块,并在输入块上进行多次传递,从而逐步执行softmax减少(也称为平铺)。(ii)我们存储前向传递的softmax归一化因子,以便在后向传递中快速重新计算片上注意力,这比从HBM阅读中间注意力矩阵的标准方法更快。我们在CUDA中实现了FlashAttention,以实现对内存访问的细粒度控制,并将所有的注意力操作融合到一个GPU内核中。即使由于重新计算而增加了FLOP,我们的算法运行速度更快(在GPT-2上高达7.6倍[67],图1右),并且与标准注意力相比使用更少的内存-序列长度线性,这要归功于HBM访问量的大幅减少。

        左图:FlashAttention使用平铺来防止大型×注意力矩阵(虚线框)在(相对)较慢的GPU HBM上物化。在外部循环(红色箭头)中,FlashAttention循环通过K和V矩阵的块,并将它们加载到快速片内SRAM。在每个块中,FlashAttention循环Q矩阵的块(蓝色箭头),将它们加载到SRAM,并将注意力计算的输出写回HBM。右图:加速GPT-2上的PyTorch实现。FlashAttention不会将大型×注意力矩阵读写到HBM,从而使注意力计算加速7.6倍。

1.标准注意力实施

在标准的Attention中,Q、K、V作为输入,大小为N×d,如下图所示,在计算中需要存储中间值S和P到显存HBM(High Bandwidth Memory)中,这会极大占用HBM。

 

图2

  2 FlashAttention的实现

FlashAttention旨在避免从 HBM中读取和写入注意力矩阵,这需要做到:

目标1:在不访问整个输入的情况下计算softmax函数;
目标2:在后向传播中不能存储中间注意力矩阵。

 

针对目标1

已知SRAM、HBM、DRAM存储容量依次升高,数据IO速度依次降低,如下图所示。那么可以将Q,K,V矩阵划分成多个小的子块,这些子块的大小恰好能从HBM加载进SRAM中,循环将子块传递进SRAM中以增量方式计算出Softmax值。

针对目标2


在后向传播中不存储中间注意力矩阵,以FlashAttention所提供的算法为例,通过对比标准Attention算法在实现过程中,标准Attention算法的实现需要将计算过程中的S、P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型,因此Flash Attention就提出了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。

在Flash Attention的前向计算算法(上图 Algorithm 2中)中我们可以看出,Flash Attention算法并没有将S、P写入HBM中去,而是通过分块写入到HBM中去,存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。
优点是:即使由于重新计算导致FLOPs增加,但其运行速度更快并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。实验对比如下图。

 左:GPT-2介质的标准注意力和FlashAttention的向前+向后运行时间(seq.长度1024,头部变暗。64,16个磁头,批量大小为64)。HBM访问是影响运行时间的主要因素。中间:FlashAttention的转发运行时间(seq.长度1024,头部变暗。64,16个磁头,批量大小为64)。更少的HBM访问导致更快的运行时间,在一定程度上。右:运行时(对于seq.块稀疏FlashAttention的长度为4K)比FlashAttention快一个与稀疏度成比例的因子。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值