PyTorch全新API:几行代码实现不同注意力变体!兼具FlashAttention性能

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>【Mamba/多模态/扩散】交流群

添加微信号:CVer111,小助手会拉你进群!

扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!

6322477ea99b7bd311a121238a69588e.png

转载自:机器之心 | 编辑:陈陈

用 FlexAttention 尝试一种新的注意力模式。

理论上,注意力机制就是你所需要的一切。然而在实际操作中,我们还需要优化像 FlashAttention 这样的注意力机制的实现。

尽管这些融合的注意力机制大大提高了性能,且支持长上下文,但这种效率的提升也伴随着灵活性的丧失。对于机器学习研究人员来说,这就像是一种「软件彩票」—— 如果你的注意力变体不适合现有的优化内核,你将面临运行缓慢和 CUDA 内存不足的困境。 

一些注意力变体包括因果注意力、相对位置嵌入、Alibi、滑动窗口注意力、PrefixLM、文档掩码、不规则张量、PagedAttention 等。更糟糕的是,人们通常希望将这些变体组合在一起!比如滑动窗口注意力 + 文档掩码 + 因果注意力 + 上下文并行,又比如 PagedAttention + 滑动窗口的组合。

下图左侧代表了当今的现状 —— 一些掩码 + 偏置 + 设置的组合已经有现成的内核实现。然而,各种选项的添加会导致设置呈指数级增长。更糟糕的是,这种方式不会支持新的注意力变体。 

ab2e0b8c8e95586addc31d3d7dd821ee.png

为了彻底地解决这个超立方体问题,PyTorch 团队引入了 FlexAttention,一个新的 PyTorch API。

  1. FlexAttention 是一个灵活的 API,允许用户使用几行惯用的 PyTorch 代码就能实现多个注意力变体。

  2. 团队人员通过 torch.compile 将其降低到一个融合的 FlashAttention 内核中 ,生成了一个不会占用额外内存且性能可与手写内核相媲美的 FlashAttention 内核。

  3. 利用 PyTorch 的自动求导机制自动生成反向传播。

  4. 最后,PyTorch 团队还可以利用注意力掩码中的稀疏性,从而显著改善标准注意力实现。

5746cd231ddc5687afc224848ec139ab.png

FlashAttention 1-3 版本的参与者 Tri Dao 对这项研究进行了转发并评论:这项研究使得很多技术都融合在一起了。

837e234119d3bd1ecdaaaa0514310998.png

FlexAttention

经典的注意力方程式如下:

6fa2e505a1074588cd454e53a00afb25.png

代码形式:

ebfd87a2497ba8fbf0e0711fd2f95ca2.png

FlexAttention 形式如下,其通过接受用户定义的函数 score_mod 来解决上述问题。

30f486810f59eaf59acae0654e61b148.png

代码形式:

9e011f7a4bcefb93f7b40fb8802e89f5.png

此函数允许用户在 softmax 之前修改注意力分数。研究人员发现,该函数最终足以满足大多数用户对注意力变体的需求。

具体而言,score_mod 如下:

5e73fa03df63558e956c831d28b35ab9.png

要应用此函数,可以将其实现为:

for b in range (batch_size):
    for h in range (num_heads):
        for q_idx in range (sequence_length):
            for kv_idx in range (sequence_length):
                modified_scores [b, h, q_idx, kv_idx] = score_mod (scores [b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)

最终的 API 具有令人惊讶的表达能力。

Score Mod 示例

全注意力

在这种情况下,score_mod 无操作,它接受分数作为输入,然后原样返回它们。

efac4e9a4f02e950df37d5038f7d097d.png

然后端到端的使用。

ef1207d1ec58206c40c79eb0f30934a7.png

相对位置编码

一种常见的注意力变体是相对位置编码。相对位置编码不是对查询和键中的绝对距离进行编码,而是根据查询和键之间的距离调整分数。

80668c7669f1861c4cb6cbee36c99693.png

需要注意的是,与典型实现不同,这不需要具体化 SxS 张量。相反,FlexAttention 会在内核中动态计算偏差值,从而显著提高内存和性能。

362763fb198c2002b3f063882ff9c742.png

Soft-capping

Soft-capping 是 Gemma 2 和 Grok-1 使用的一种技术,在 FlexAttention 中,它的形式是这样的:

4e317f72f669a6b20ae0f3251e8448bf.png

Causal Mask

尽管双向注意力很简单,但在论文《Attention is All You Need》,以及其他的 LLM 中,它们的设置都是仅解码器的注意力,其中每个 token 只能关注它之前的 token。如果用户使用 score_mod API ,可以将其表示为:

93930b0d6978c0232f516fd7ff7ee9c3.png

Sliding Window + Causal

c0fd34204f2d12cc089904051734401d.png

图源:https://arxiv.org/abs/2310.06825

Mistral 一直在推广滑动窗口注意力(也称为局部注意力),它允许查询 token 仅关注最近的 1024 个 token,通常与因果注意力一起使用。

ebbc0e551e2a1e0ca949770b6e034da7.png

研究者对带有滑动窗口掩码的 F.scaled_dot_product_attention 以及带有因果掩码的 FA2 进行基准测试。结果表明,FlexAttention 不仅明显快于 F.scaled_dot_product_attention,也明显快于带有因果掩码的 FA2。

0608536342bf3d6e2c4223c57b322ce7.png

性能

总体而言,FlexAttention 的性能几乎与手写的 Triton 内核一样好。然而,由于 FlexAttention 具有通用性,因此会遭受轻微的性能损失。例如,用户必须承受一些额外的延迟。

FlexAttention 在前向传播中实现了 FlashAttention2 性能的 90%,在反向传播中实现了 85%。FlexAttention 目前正在使用一种确定性算法,该算法比 FAv2 重新计算了更多的中间体,研究者计划改进 FlexAttention 的反向算法,来缩小这一差距!

15b986e114b2b7dc183a1c531bd27f68.png

9aa9ae6ec364b2ef0cdcb62fde1ea66c.png

参考链接:https://pytorch.org/blog/flexattention/

 
 

何恺明在MIT授课的课件PPT下载

 
 

在CVer公众号后台回复:何恺明,即可下载本课程的所有566页课件PPT!赶紧学起来!

ECCV 2024 论文和代码下载

在CVer公众号后台回复:ECCV2024,即可下载ECCV 2024论文和代码开源的论文合集

CVPR 2024 论文和代码下载

在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集

Mamba、多模态和扩散模型交流群成立

 
 
扫描下方二维码,或者添加微信号:CVer111,即可添加CVer小助手微信,便可申请加入CVer-Mamba、多模态学习或者扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要备注:研究方向+地点+学校/公司+昵称(如Mamba、多模态学习或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

 
 
▲扫码或加微信号: CVer111,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集上万人!

▲扫码加入星球学习
 
 
▲点击上方卡片,关注CVer公众号
整理不易,请赞和在看
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值