解读 FlashAttention

       

       欢迎来到雲闪世界。当我思考下一个系列的主题时,解释注意力机制如何工作的想法立即浮现在我的脑海中。事实上,在推出新系列时,从基础开始是一种明智的策略,而大型语言模型 (LLM) 正是人们热议的话题。

然而,互联网上已经充斥着有关注意力的故事——注意力的机制、功效和应用。所以,如果我想在开始之前就让你不打瞌睡,我必须找到一个独特的视角。

那么,如果我们从不同的角度来探讨注意力的概念,结果会怎样呢?与其讨论注意力的好处,不如研究它带来的挑战,并提出一些缓解挑战的策略。

秉承这一思路,本系列将重点介绍 FlashAttention:一种快速且内存高效的精确注意,具有 IO 感知功能。这个描述乍一看可能让人不知所措,但我相信到最后一切都会变得清晰。

本系列将遵循我们惯常的格式:四部分,每周发布一期。

您正在阅读的第一部分是 FlashAttention 的 ELI5 介绍,概述了它的动机、意义、使用传统注意力算法实现类似结果的挑战以及其背后的核心概念。

第二部分深入探讨了注意力机制的原理(是的,它很有必要,你无法逃避它)。此外,我们还将讨论原始注意力算法的可扩展性挑战。

在第三部分中,我们将详细探讨 GPU 内存层次结构——这是了解 FlashAttention 如何实现其效率的关键组件。什么是 HBM,什么是 SRAM,以及我们如何充分利用每个组件。我相信这是本系列中最有趣的部分之一。

最后,第四部分也是最后一部分,我们将重新回顾 FlashAttention,我们将使用 PyTorch 从头开始重新实现该算法。在这篇文章中,我保证一切都会水到渠成,让您自信地说您了解 Attention 是什么、它如何运作以及 FlashAttention 如何增强它。

让我们开始吧!

寻求更深层次的理解

既然我们已经站在起跑线上,我们就要控制好节奏——暂时还不需要冲刺。我的目标是,使用 ELI5 方法,慢慢揭开 FlashAttention 背后的故事,层层剥开它为何如此重要却又如此具有挑战性,同时保持故事的崇高和轻松。

开发 FlashAttention 的动力源自 AI 社区对处理越来越长的数据序列的不断增长的野心。这种追求不仅仅是出于学术上的好奇,也不是出于创建更大模型和更大 GPU的愿望,以追随《Big, Bigger, Biggest》纪录片中的趋势。这种能力的潜在好处是深远的,可以更深入地理解书籍和说明书等复杂文本。这与通过创建更高分辨率图像的数据集在计算机视觉领域取得的进步相一致。

想象一下:你正在疯狂观看你最喜欢的节目,但每次播放新剧集时,你对上一集的记忆就会被抹去。这真是一场噩梦般的观影体验!说实话,每集开始前的“变形金刚前瞻”部分的作用有限。

那么,是什么阻止我们在更长的序列上训练我们的 LLM 并扩展它们的上下文窗口以适应整本书呢?很高兴你问这个问题!

扩展的挑战

建立更长序列模型的雄心遭遇了巨大的阻碍:注意力机制是当今 LLM 所采用的 Transformer 架构的核心,但随着序列长度的增加,该机制难以扩展。由于序列中的每个 token 都试图与其他每个 token 进行有意义的对话,无论它们相距多远,延长序列就像是为派对增加了更多客人。这听起来很有趣,直到你意识到你必须与每个人聊天。

问题有两个方面:注意力机制的计算复杂度不仅会随着序列长度的增加而呈二次方增长,而且内存利用率低也会阻碍其发展。序列长度每增加一倍,对 GPU(相对较慢的)高带宽内存 (HBM) 的读写需求就会增加四倍(这个名字不太好),从而形成瓶颈,阻碍可扩展性。

为了将上一句话形式化,我们说注意力的算术强度(定义为算术运算与记忆访问的比率)偏低。虽然这个术语对于掌握大局可能并不重要,但现在你可以在谈话中使用它,听起来更聪明。

有什么想法?

FlashAttention 是一种革命性的方法,旨在正面应对这些挑战。通过使注意力算法了解其环境,从而使其具有 I/O 感知能力,FlashAttention 将传统方法转变为一种更快、更节省内存的过程。

核心思想是通过利用更智能的数据加载策略来最大限度地减少冗余的 HBM 读写,例如重新使用经典的平铺技术,并动态重新计算反向传递所需的一切。现在,这句话加上下面的图,可能会让你濒临放弃。但坚持下去!当我们到达本系列的结局时,一切都会水到渠成,就像拼图的最后一块一样令人满意。

左图:FlashAttention 使用平铺来防止在(相对较慢的)GPU HBM 上实现较大的 𝑁 × 𝑁 注意力矩阵(虚线框)。在外循环(红色箭头)中,FlashAttention 循环遍历 K 和 V 矩阵块并将它们加载到快速的片上 SRAM 中。在每个块中,FlashAttention 循环遍历 Q 矩阵块(蓝色箭头),将它们加载到 SRAM,并将注意力计算的输出写回 HBM。右图:在 GPT-2 上加速 PyTorch 注意力实现。FlashAttention 不会读取和写入较大的 𝑁 × 𝑁 注意力矩阵到 HBM,从而使注意力计算速度提高了 7.6 倍。— Dao, Tri, et al. “Flashattention:具有 io 感知的快速且内存高效的精确注意力。”《神经信息处理系统进展》 35 (2022):16344–16359。

但是,即使你现在感到迷茫,如果你仔细看看右边的最后一张图,你就会发现注意力机制的大部分计算时间都花在了计算 Softmax 和 Dropout 之类的东西上,而它很少进行矩阵乘法。由于元素级操作(例如计算 Softmax 或通过 Dropout 层将一堆激活归零)主要受内存限制,因此整个注意力操作都是受内存限制的操作。

换句话说,我们基本上处于待机模式,无所事事地等待数据从主内存传输回来。仅凭这一点,您就应该知道需要做些什么才能加快速度!

平铺的力量

FlashAttention 策略的核心是平铺概念,即将大型矩阵分解为较小的块。这些块可以在 GPU 的 SRAM 中高效处理,从而显著加快计算速度。

平铺允许注意力操作以块的形式执行,从而减少内存开销并能够处理更长的序列,而不会增加内存访问量。

把拼贴想象成尝试制作一个巨大的披萨,但不是试图将整个巨大的披萨塞进烤箱,而是巧妙地将其分成易于管理的切片。然后,这些切片可以在 GPU 的 SRAM 中尺寸较小的烤箱中逐个完美地烘烤,然后,您可以将它们合并回去参加市政厅会议。

更多操作≠更多时间

FlashAttention 的另一个特点是它在算法的反向传递中使用了重新计算。FlashAttention 巧妙地只存储快速重新计算注意力所需的内容,而不是存储可能非常大的整个注意力矩阵。

再次强调,这些术语现在可能听起来令人困惑,但我们最终会将其全部解开。我无法想出一个类比来解释这一点!但相信我,这是整个想法中最简单的部分,所以我们会在准备好时重新讨论它。

展望未来

随着我们深入研究本系列,我们将揭开注意力层的秘密,仔细研究 GPU 内存的工作原理,最后以令人兴奋的 FlashAttention 结束。

感谢关注雲闪世界。(Aws解决方案架构师vs开发人员&GCP解决方案架构师vs开发人员)

 订阅频道(https://t.me/awsgoogvps_Host)
 TG交流群(t.me/awsgoogvpsHost)


 

  • 14
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值