大模型论文:FlashAttention Fast and Memory-Efficient Exact Attention with IO-Awareness(效率提升)

大模型论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness(效率提升)

文章地址:2205.14135 (arxiv.org)





摘要

Transformer 在处理长序列时速度慢、内存开销大,其原因在于自注意力机制的时间和内存复杂度与序列长度呈二次关系。尽管已有一些近似注意力机制尝试通过牺牲模型精度来降低计算复杂度,但这些方法往往无法在实际运行时间上实现提速。我们认为,这些方法缺少的一个核心原则是 IO 感知(IO-aware) —— 即在算法设计中考虑 GPU 不同层级内存之间的数据读写开销。

我们提出了 FlashAttention,一种 IO 感知的精确注意力算法,通过 分块(tiling)机制 来减少 GPU 高带宽内存(HBM)与片上 SRAM 之间的数据读写操作。我们对 FlashAttention 的 IO 复杂度进行了分析,结果显示它相比标准注意力机制显著减少了对 HBM 的访问次数,并在一定范围的 SRAM 大小下达到了最优。

此外,我们还将 FlashAttention 拓展到 块稀疏注意力(block-sparse attention),得到了一个 近似注意力算法,该算法比现有所有近似方法都要快。

在实际应用中,FlashAttention 在 Transformer 训练中表现出显著优势:

  • 在 BERT-large(序列长度 512)上,相比 MLPerf 1.1 的训练速度记录,整体训练时间加速 15%
  • 在 GPT-2(序列长度 1000)上,加速达到 3 倍
  • 在 Long Range Arena(序列长度 1000–4000)任务中,加速达到 2.4 倍

此外,FlashAttention 及其块稀疏版本还使 Transformer 能够处理更长的上下文,从而带来更高的模型质量:

  • GPT-2 上的困惑度(perplexity)提升了 0.7
  • 长文档分类任务上准确率提升 6.4 个百分点
  • 成为第一个在 Path-X(序列长度 16K,准确率 61.4%)Path-256(序列长度 64K,准确率 63.1%) 任务中取得优于随机猜测表现的 Transformer 模型

简而言之,FlashAttention 通过 IO 感知设计,真正做到了在实际训练时间上实现加速,同时提升了模型的能力与质量。





问题分析

  • Transformers在长序列下自注意力机制复杂度为 O(n²)(n为序列长度),导致:

    • 内存爆炸
    • 训练时间剧增
    • 实际训练时间(wall-clock time)并无提升
  • 自注意力模块的时间与内存复杂度随着序列长度呈二次增长(O(n²))。因此一个关键问题是,如何提升注意力机制的速度和内存效率,从而帮助Transformer在长序列场景中更好运行

  • 本文认为之前的方法没有考虑到的东西是:使注意力算法具备 IO 感知能力。即:仔细考虑不同类型GPU内存之间的数据读写成本,比如快的SRAM与慢的HBM之间的传输

  • 为什么仅优化FLOPs(浮点运算数)不够?

    • 许多近似注意力方法旨在降低注意力计算的计算量和内存需求。这些方法包括从稀疏近似sparse-approximation 到低秩近似low-rank approximation 。尽管这些方法将计算需求降低到与序列长度呈线性或接近线性的关系,但其中许多方法与标准注意力相比,并没有实现实际运行时间的加速,也没有得到广泛应用。一个主要原因是,它们专注于减少浮点运算次数(FLOP reduction)(这可能与实际运行速度并无关联),并且往往忽略了内存访问(memory access, IO)带来的开销

    • 🧠 举个例子:

      你在 Vim 编辑一个巨大的文件,如果每一次操作都要从慢速硬盘读写,哪怕你操作再快也没用。FlashAttention 就是尽量在缓存(SRAM)中完成计算,减少“去硬盘”(HBM)的次数
      FlashAttention示意图





FlashAttention

  • 本文提出 FlashAttention,一种在保证准确性的前提下,大幅减少内存访问的注意力机制。目标是避免对 HBM 中的注意力矩阵进行读写操作

  • 我们采用两个关键技术(上图):

    1. 分块计算(Tiling):将输入分成多个小块(block),逐块处理注意力,从而避免一次性读写整个矩阵
    2. 重计算(Recompute)优化:在前向传播阶段只存储 Softmax 缩放因子,在反向传播时用它来快速重建注意力值,而不是从HBM读取完整矩阵
  • 相比于原始Transformer的自注意力,FlashAttention具有以下特点:

    • 在不访问整个输入的情况下计算 softmax reduction(归约)
    • 在反向传播中不存储大型的中间注意力矩阵
    • 重新组织注意力计算过程,将输入分割成块,并对输入块进行多次遍历,从而逐步执行 softmax 归约(也称为平铺)
    • 存储前向传播中的 softmax 归一化因子,以便在反向传播中在芯片上快速重新计算注意力,这比从 HBM 读取中间注意力矩阵的标准方法更快
  • 本文用 CUDA 实现 FlashAttention,将所有操作融合在一个 GPU kernel 中。这不仅加速了模型训练(在 GPT-2 上最高提升 7.6 倍),还降低了内存消耗(因 IO 减少而线性增长)

  • 复杂度分析:FlashAttention 的HBM访问复杂度为 O ( N 2 ⋅ d 2 ⋅ M − 1 ) O(N²·d²·M⁻¹) O(N2d2M1),其中: N N N 是序列长度,$d 是头维度, 是头维度, 是头维度,M 是 S R A M 大小。而传统注意力需要 是SRAM大小。而传统注意力需要 SRAM大小。而传统注意力需要 O(Nd + N²) $次 HBM 访问。FlashAttention最多能减少9倍的HBM访问

  • 分块(Tiling)机制:将序列按块分割,每次只处理一个 block:

    • 比如:处理一个 1K 的序列时,我们将其分成 8 个 128 长度的 block,分别计算 attention
    • 每次操作只涉及小范围的局部缓存,极大减少了全矩阵访问
  • 重计算技巧:FlashAttention 在反向传播时不再读取存储在HBM的注意力矩阵,而是快速用Softmax因子重建结果

  • 原始的self-attention实现:给定输入序列 Q , K , V ∈ R N × d Q, K, V \in \mathbb{R}^{N \times d} Q,K,VRN×d,其中 N N N 是序列长度, d d d 是头维度,我们想要计算注意力输出 O ∈ R N × d O \in \mathbb{R}^{N \times d} ORN×d
    S = Q K ⊤ ∈ R N × N , P = s o f t m a x ( S ) ∈ R N × N , O = P V ∈ R N × d S = QK^\top \in \mathbb{R}^{N \times N}, \quad P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N}, \quad O = PV \in \mathbb{R}^{N \times d} S=QKRN×N,P=softmax(S)RN×N,O=PVRN×d

    其中 softmax 是按行应用的

    标准注意力机制的实现会将矩阵 S S S P P

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小白学C++.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值