FIT: Far-reaching Interleaved Transformers

Paper name

FIT: Far-reaching Interleaved Transformers

Paper Reading Note

Paper URL: https://arxiv.org/pdf/2305.12689

Code URL: https://github.com/google-research/pix2seq

TL;DR

  • 2023 年 google deepmind 提出的 FIT 网络架构,具有高效的自注意力和自适应计算,可以作为编码器、扩散解码器或自回归解码器使用。FIT 能有效降低计算量的同时保持模型的精度。值得注意的是,FIT展示了在千兆级数据(如6400×6400图像或160K tokens(经过补丁标记化后))上进行端到端训练的潜力,在16GB内存容量下,无需特定优化或模型并行化。

Introduction

背景

  • Transformer 旨在处理一组数据 tokens,例如文本 tokens 或图像补丁 tokens。它采用自注意力机制,实现了tokens之间的全互联信息交换,导致 O(L²) 的复杂度。尽管 Transformers 在各个领域展示了成功,其二次复杂度在处理更长序列时却带来了限制。尽管已经有努力解决这一挑战,但全二次注意力机制仍然是最有效和常用的,特别是对于较短的序列。

本文方案

  • 为了利用二次注意力处理长序列,我们借鉴了自然数据可以分组组织的方式。例如,书中的文本 tokens 可以分成章节,而图像中的补丁 tokens 可以组织成块或窗口。在每个组内,我们可以使用二次注意力的高带宽通信通道,而组之间可以使用带有有意义压缩的低带宽通道。
  • 这种将数据 tokens 分组或分段的方法在若干现有工作中已成功应用。然而,协调局部(组内)和全局(组间)信息处理的机制仍然未得到充分探索。在这项工作中,我们精心设计了一种有效协调局部和全局处理的机制。
    • 首先,我们为每个组引入了一小组潜在 tokens。
    • 然后,我们交错使用两种 Transformer 层,一种用于使用局部/窗口注意力处理数据 tokens,另一种用于使用全局注意力处理潜在 tokens。交叉注意力用于在同一组内的数据 tokens 和潜在 tokens 之间传递信息。网络的单次前向传递涉及数据 tokens 和潜在 tokens 的迭代更新,确保局部和全局信息得到充分整合。
    • 我们在高分辨率图像理解(作为图像编码器)和生成(作为扩散解码器和自回归解码器)上评估了所提出的架构,并提供了初步证据,证明该架构可以作为高效且有效的 Transformers 扩展,用于处理和生成长序列。

Methods

我们简要概述了所提出架构中的关键概念,称为FIT 或 FitTransformer.

  • 分组 (Groups):

    • Transformers 在一组数据 tokens 上操作,通过位置编码管理 tokens 的顺序,而不是它们在计算机内存中的特定布局。Transformers 的输入表示为 x ∈ Rb×L×c,其中输入形状为(批量大小、tokens数量、token维度)。为了便于处理,我们将一组数据 tokens 分成多个组。这实际上将输入 x 重新组织为Rb×t×n×c,其中新形状表示(批量大小、组的数量、每组的 tokens 数量、token 维度),且 L = t × n。数据分组的过程是灵活的,可以通过直接拆分或将序列重塑为子序列来实现。对于图像而言,它涉及将图像分块为子图像,每个子图像被视为一个独立的组。
  • 数据(局部)tokens 与潜在(全局)tokens:

    • 在 FIT 的背景下,我们区分数据 tokens 和潜在 tokens。
      • 数据 tokens 对应于标准 Transformers 中使用的 tokens,通常与特定的数据元素相关联。例如,在图像的情况下,数据 token 可以表示一个补丁嵌入向量。即使在通过 Transformer 层进行转换后,数据 tokens 仍然与数据的特定部分相关联。
      • 另一方面,潜在 tokens 是一小组额外引入的 tokens,通常表示为位置嵌入,最初并不直接与基础数据相关联。然而,在前向传递期间,潜在 tokens 动态聚合信息并与数据的特定部分相关联。这个自适应过程因示例而异,使得模型能够形成长期记忆并有效压缩数据 tokens 中的信息。
  • 局部 Transformer 层与全局 Transformer 层:

    • 局部和全局 Transformer 层结构相似,均包含标准的自注意力模块和前馈网络。然而,它们在模型中处理不同的 tokens 集合。
      • 局部 Transformer 层应用于每个组内的数据 tokens。这些层处理各自组内的数据 tokens,允许局部信息处理并捕捉组内 tokens 之间的细粒度关系。值得注意的是,可以通过用其他架构构件(如卷积)替换局部 Transformer 层,或通过去除自注意力来简化它们。
      • 另一方面,全局 Transformer 层负责处理跨所有组的潜在 tokens。这些层使模型能够捕捉输入不同部分之间的全局依赖关系和长距离关系。

FIT 架构

  • 下图为 FIT 架构,处理形状为Rb×t×n×c的数据tokens和形状为Rb×t×m×d的潜在tokens,其中m ≪ n。
    在这里插入图片描述
  • FIT 实现伪代码
    在这里插入图片描述

扩展基本FIT用于自回归建模

  • 在自回归(语言)建模中,防止未来数据 tokens 的信息流入过去是至关重要的。可以通过在全局 Transformer 层采用分块因果掩码来轻松实现这一要求,这允许组内潜在 tokens 之间完全可见,但在组之间强加因果掩码。然而,在基本的 FIT 架构中,存在信息可能在同一组内从未来 tokens 无意中泄漏到过去的问题。为了解决这个问题,我们引入了在推送和拉取信息之间的潜在 tokens 的偏移概念。具体来说,当第 i 组的数据 tokens 向同组的潜在 tokens 推送信息时,它们需要从第 (i-1) 组的潜在 tokens 拉取信息。这种偏移机制确保了信息流动的一致性和因果性,防止任何未来信息无意中泄漏到过去。这在图 2 中有所说明,FIT-AR 架构的训练伪代码在算法 2 中给出。在推理方面,模型仍然以自回归方式一次解码一个 token,但 FIT 中存在的潜在 tokens 总结了前面的数据 tokens,这可以显著提高长序列的解码速度,同时减少内存使用。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

复杂度与效率分析

  • FIT 相比于标准 Transformers 在效率上提供了两个显著提升:首先,通过交替使用局部和全局注意力,它显著降低了注意力层的复杂度,从 O(L²) 的二次复杂度降低到最优复杂度 O(L^(4/3))。其次,该架构支持自适应计算。通过将局部 Transformer 层的处理转移到全局 Transformer 层(其操作在一组较小的自适应潜在 tokens 上),整体计算成本进一步降低。这些效率提升使所提出的架构非常适合处理长序列,同时保持计算的可处理性。表 1 分解了标准 Transformers 和 FIT 的计算成本。
    在这里插入图片描述

  • 图 3 展示了一个仅解码 Transformer 模型的案例研究,重点分析了浮点运算次数(FLOPs)。该模型大约有 130 亿参数,包括 40 层,隐藏维度为 5120(对于数据和潜在 tokens)。组/窗口大小固定为 2048,对于所有序列长度,这导致理论上的注意力复杂度为 O(L²)。然而,对于每组 64 个潜在 tokens,全局注意力仅作用于约 3% 的数据 tokens。
    在这里插入图片描述

  • 在图 4a 中,我们观察到用窗口/组注意力替换标准 Transformers 的完全注意力可以显著减少长序列的 FLOPs。然而,在这种情况下,组之间没有全局交互。通过 FIT,我们引入了全局 Transformer 层,以实现跨组交互。值得注意的是,由于潜在 tokens 数量减少,全局 Transformer 层所需的额外 FLOPs 相对较少,即使对于 100 万个 tokens 的序列长度也是如此。此外,图 4b 显示了局部层的计算相比全局层要昂贵得多。因此,通过将计算从局部层转移到全局层,可以进一步减少 FLOPs。
    在这里插入图片描述

Experiments

在三个任务上评估了所提出的架构:
1)通过 Pix2Seq 对象检测在 object365 上的高分辨率图像理解
2)通过基于像素的去噪扩散模型在 512×512 或 1024×1024 分辨率的Imagenet 上的高分辨率图像生成
3)在 Imagenet-64×64 上的基于像素的自回归图像生成

使用Pix2Seq进行高分辨率图像理解

  • 使用预训练的负对数似然(Nll)评估不同编码器的性能。我们将图像分为16个子图像,每个子图像作为一个组,并为每组分配32个潜在标记。为了与ViT进行比较,我们保持相似的架构,但增加了一些全局层。因此,标准的ViT层现在对应于在每个组内独立运行的局部层。

    • 对于FIT-B,我们有L(4)G(2)L(4)G(2)L(4)层,而对于FIT-L,我们有L(6)G(2)L(6)G(2)L(6)G(2)L(6)层,其中L/G分别表示局部和全局层
    • 实验主要关注640 × 640的图像
  • FIT不仅在训练过程中增加了每秒步骤数,还降低了损失
    在这里插入图片描述

基于像素的端到端扩散建模

  • RIN 是最近在去噪扩散模型的架构设计和建模方面的进展,已经展示了直接在高达 1024×1024 分辨率的高分辨率图像上训练的能力。如前所述,RIN 可以视为基本 FIT 架构的一个特例,具有一组 token 且局部层中没有自注意力

    • 使用预测和目标之间的均方误差(MSE)
  • 将 RIN 转变为 FIT,我们观察到均方误差(MSE)显著降低,训练速度(在 TPUv3 上以每秒步骤数衡量)显著提高
    在这里插入图片描述

基于像素的自回归图像建模

  • 以自回归方式直接将像素建模为离散 token,由于序列长度较长(例如, 64x64 图像产生 12,288 个数据 token)且需要捕捉局部和全局依赖关系,因此面临挑战。在我们的方法中,我们将像素局部分组为 8×8 的块,每组产生 192 个数据 token,并为每组使用 32 个潜在 token。我们使用 512 维的数据 token 和 768 维的潜在 token,层配置为 L(8)G(2)L(8)G(2)L(8)G(2)L(8)。表 4 中的总结结果表明,考虑到模型的规模,性能接近最先进的水平。
    在这里插入图片描述

消融实验

latent number 影响
  • 使用更多潜在 token 效果会更好。增加潜在 token 的数量并不一定显著增加参数数量或训练时间(每秒步骤数),特别是在局部层对计算贡献较多时
    在这里插入图片描述
层交错模式的影响
  • 探讨了在保持局部和全局层数量不变的情况下(除了仅包含局部层的情况,该情况对应于原始的局部和全局层),各种层交错模式的效果,并且我们对两种类型的层使用相同的隐藏维度。值得注意的是,我们观察到,交错使用局部和全局层对于实现最佳结果至关重要,同时保持大致相同的训练效率(在TPUv3上以每秒步骤数衡量)。
    在这里插入图片描述

Thoughts

  • FIT 整体的思路:
    • 连接在不同的数据 token group 上独立操作的局部 Transformer 与提供上下文反馈的全局 Transformer
    • 通过一组潜在 token 来实现全局 Transformer 的可学习 token,这些潜在 token 有选择性地关注数据 token,从而实现更自适应的计算
  • 感觉实验中缺少对变长输入的建模能力分析,比如图像、视频联合训练时用同样的潜在 token 空间来同时进行表征可能不是很科学,如果在这种变长输入下挑选合适的 latent dim 还需要一些实验验证。
  • 16
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值