ViR: Towards Efficient Vision Retention Backbones
公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
目录
0. 摘要
视觉变换器(Vision Transformers,ViT)近年来受到了很大的欢迎,这归因于它们在建模长程空间依赖性和大规模训练的卓越能力。尽管自注意力机制的训练并行性在保持出色性能方面发挥着重要作用,但其二次复杂度使 ViTs 在许多需要快速推理的场景中难以应用。在需要输入特征的自回归建模的应用中,这种效应甚至更加显著。在自然语言处理(NLP)领域,一系列新的努力提出了可以并行化的模型:具有循环公式(recurrent formulation),可以在生成应用中进行高效推理。在受到这一趋势启发的情况下,我们提出了一类新的计算机视觉模型,被称为 Vision Retention Networks(ViR),具有双重并行和循环公式,它在快速推理和并行训练之间取得了最佳平衡,并具有竞争性的性能。特别是,ViR 在处理大序列长度的灵活公式中,对于需要更高分辨率图像的任务,在图像吞吐量和内存消耗方面表现出色。ViR 是在通用视觉骨干中实现双重并行和循环等效性的首次尝试,用于识别任务。我们通过对不同数据集大小和各种图像分辨率进行了大量实验证实了 ViR 的有效性,并取得了竞争性的性能。
代码:https://github.com/NVlabs/ViR
(注:ICLR 2024 拒稿)
OpenReview:https://openreview.net/forum?id=cmAIfTK6fe
2. 相关工作
自注意力的替代方案。为了解决自注意力的二次计算复杂性,许多努力提出了各种方法,例如对softmax 激活函数的近似 [14, 23],通过使用其他核心进行线性注意力 [24, 45] 以估计注意力分数,或在通道特征空间中计算注意力 [1]。然而,提高效率会对模型的性能产生负面影响。其他工作 [18, 55] 还提出了完全用其他机制替代自注意力的方法。特别是,在最近的自然语言处理中,RWKV [33] 和 RetNet [38] 提出重新定义 Transformer 以利用并行和循环公式的二元性进行训练和推理。RWKV 采用无注意力的公式 [55],但使用指数衰减来启用循环公式。RetNet 提出使用多尺度门控保留(multi-scale gated retention)来保持上下文信息的表达能力并取得竞争性能。虽然我们的工作受到了 RetNet 的启发,但它针对的是计算机视觉,特别是识别,具有定制的保留机制和架构重新设计以实现最佳性能。
3. 方法
3.1. 1D 保留
在本节中,我们讨论保留机制及其不同的公式 [38]。考虑一个输入序列 X ∈ R^(|X| x D),将以自回归方式进行编码。给定状态为 s_n 的查询(q_n)、键(k_n)和值(v_n),这个序列到序列的映射可以写成:
其中,Ret 和 γ 分别表示保留和衰减因子。实质上,s_n 方便地保持了先前的内部状态。正如 [38]所示,保留也可以在并行公式中定义:
其中,M 表示一个具有衰减因子 γ 的掩码,如
保留在并行和循环模式中的双重表示使得许多期望的特性成为可能,如训练并行性和快速推理。对于更长的序列,循环模式可能变得效率低下。因此,一种混合方法,称为分块(chunkwise),结合了循环和并行公式,是可取的。具体而言,输入 X 被分割成具有块大小 C 的较小序列,其中
表示第 m 个块。分块查询、键和值可以定义为
分块保留的公式如下
分块公式的基本动机是在处理横跨块表示时,在每个块中采用并行模式同时使用循环模式。对于具有长序列的高分辨率图像,分块公式允许更快地处理标记并解耦内存。在第 5.3 节中,我们演示了ViRs 如何由于分块公式而与 ViTs 相比更有利于对较长序列的高效处理。
3.2. 2D 保留
我们进一步扩展 1D 公式以实现平移等变性。在 1D 公式下,沿图像列的相邻块之间的衰减增加了一个称为 W 的因子,该因子是图像每行的块数。我们的 2D 公式旨在保持相邻水平和垂直位置之间的衰减。
3.2.1 2D 循环公式
给定一个点 (x, y),我们重新以函数形式 r(x, y) 表示方程 1,以便使用 x 和 y 坐标对序列中的位置进行参数化,其中 x、y 属于正整数集 。我们将其表述为
我们采用位置 (p+f, y+g) 和 (x, y) 之间的 L1 距离作为衰减率,获得
我们保留 retention 的自回归属性,因此强制 f, g ≥ 0。此外,我们导出了 2D retention 在循环形式中的公式,如下所示:
方程 8 的前 3 项可以看作递归中的基本情况。实际上,r(x, 1) 和 r(1, y) 的形式与原始的保留公式相同。在我们介绍 2D 保留的并行形式时,下一节将更清楚地解释广义形式 r(x, y) 的直觉。至关重要的是,这个形式仍然允许以常数时间复杂度计算 r(x, y),因为它计算的是一定数量的项的总和(r(x-1, y), r(x, y-1), r(x-1, y-1))。
3.2.2 2D 并行公式
为了方便表示,让 Δx 等于 x-f,Δy 等于y-g,其中 f 不大于 x,g 不大于 y,而 x、y、f、g 属于正整数集。在此基础上,我们引入并行公式:
在并行公式中,L1 距离如何支撑衰减率也变得更加明显,因为它直接应用在并行公式中。请参阅补充材料,了解并行和循环公式之间等价性的证明。
为了构建并行公式的完整衰减掩码,我们引入了完整的标记序列 s ∈ S,以及在其中的位置,然后x'(s) = s mod W,y'(s) = s mod W。因此,Δx' 等于 x'(c) - x'(r),Δy' 等于 y'(c) - y'(r)。因此,该掩码表示为
3.3. ViR 模型
在接下来的部分,我们首先讨论各向同性的 ViR 模型。此外,我们介绍了混合 ViR,它由 CNN 和基于 retention 的层组成,其中包含了诸如局部性和权重共享等归纳偏差,可以提高训练和数据效率。
3.4. 各向同性
图 2 展示了我们提出的模型的概述。给定一个输入图像 X 属于 R^(HxWxC),其中 H 和 W 分别为高度和宽度,将其划分为块,并展平成一个 token 序列。这类似于之前由 ViT [13] 提出的分词化。分词化的块然后被投影到具有维度 D 的块嵌入 Z = [z_1, …… , z_|z|] ∈ R^(|z| x D)。与 ViT 不同的是,我们首先将位置嵌入添加到块嵌入中,然后附加一个 [class] 标记 (Z^0_n = X_class)。
具有 L 层(Z^n_L)的 ViR 编码器的输出,在预训练和微调期间都用于一个分类多层感知机(MLP)头。由于 ViR 模型的自回归性质,[class] 的位置起着重要作用,因为将其附加到嵌入序列的末尾相当于对所有先前标记进行汇总。
我们使用保留而不是自注意力来通过掩码实施循环公式。然而,我们的公式不依赖于门控保留或特定的相对位置嵌入(例如 xPos [37] 或 RoPE [36])并在并行、循环和混合(即局部循环和全局并行的混合)公式之间实现了数值等价性。具体而言, 并行保留公式仅取决于查询 q,键 k,值 v 和衰减掩码 M,定义如下
其中,Ret 代表保留,D_h 是一个用于平衡计算和参数数量的缩放因子。此外,原始的保留公式,如在 RetNet [38] 中提出的,由于添加了可学习的门控函数,增加了参数数量,导致在相同的网络布局下降低了图像吞吐量。
保留(Ret)进一步扩展为多头保留(Multi-Head Retention,MHR)。保留是通过每个头部计算的,具有恒定的衰减因子,并根据 LayerNorm [2](LN)进行归一化,如下所示:
我们根据以下方式使用交替的 MHR 和 MLP 块,其中包括 Layer-Norm(LN)和残差连接,作为编码器的构建基块:
3.5. 混合
混合 ViR(HViR)具有多尺度架构,共有四个具有不同分辨率的阶段。较高分辨率的特征在前两个阶段中进行处理,这包括具有残差连接的基于 CNN 的块。具体而言,给定输入 h,它被定义为
其中,Conv_(3x3) 是一个密集的 3x3 卷积层,BN表示批归一化 [22]。较低分辨率的阶段包括与第3.4 节中描述的相似的保留块。有关不同 HViR 模型变体的体系结构细节,请参阅补充材料。
4. 实验
5. 消融
5.4. 保留看到了什么?
在图 6 中,我们展示了从经过 ImageNet-1K 预训练的 ViR-S/16 模型获得的保留图。具体而言,这些保留图是从编码器的最后一层提取的,没有使用任何后处理或标准化层。我们观察到,高强度响应区域对应着显著的图像特征。对于细长的物体,长程空间依赖性已经被有效地捕捉到。我们观察到在其他经过 ImageNet-1K 和 ImageNet-21K 数据集训练的 ViR 变体中存在类似的趋势。