MambaOut: Do We Really Need Mamba for Vision?
论文链接:http://arxiv.org/abs/2405.07992
代码链接:https://github.com/yuweihao/MambaOut
Mamba真的没有用吗?
文中发现Mamba对于ImageNet的图像分类任务并非必需,而在其他下游任务(目标检测和语义分割等)上对模型性能是有提升的。
1、摘要
本文深入探讨了Mamba的本质,并在概念上得出结论,即Mamba非常适合具有长序列和自回归特性的任务。对于视觉任务,由于图像分类既不符合长序列特性也不符合自回归特性,假设Mamba对于这项任务并非必要;检测和分割任务也不是自回归的,但它们符合长序列特性,因此认为它们仍然值得探索Mamba在这些任务中的潜力。为了从经验上验证该假设,文中构建了一系列名为MambaOut的模型,通过堆叠Mamba块并移除它们的核心token混合器SSM。实验结果强有力地支持我们的假设。具体而言,MambaOut模型在ImageNet图像分类上超越了所有视觉Mamba模型,表明Mamba在这项任务中确实是不必要的。至于检测和分割,MambaOut无法与最先进的视觉Mamba模型的性能匹敌,展示了Mamba在长序列视觉任务中的潜力。
2、创新点
-
分析了SSM的类似RNN机制, 并理论上得出结论,Mamba适合处理具有长序列和自回归性质的任务。
-
研究了视觉任务的特点, 提出在ImageNet图像分类任务中, 由于不具备这两种特性,SSM可能并非必需。 然而,探索SSM在检测和分割任务中的潜力仍然有价值,尽管它们是非自回归的,但这些任务符合长序列特性。
-
基于Gated CNN模块构建了MambaOut系列模型,但不包含SSM。实验表明,MambaOut在ImageNet图像分类中有效超越了视觉Mamba模型,但在检测和分割任务上未达到最先进的性能。这些观察结果反过来验证了假设2。因此,由于MambaOut的奥卡姆剃刀原则,它可能成为未来视觉Mamba模型研究的自然基线。
3、原理
Mamba特别适合具有两个关键特性的任务:长序列和自回归,因为SSM的内在RNN机制(见图2和图3)。
但很少有视觉任务同时具备这两个特性。例如,ImageNet上的图像分类既不符合长序列也不符合自回归。而在COCO上的对象检测与实例分割和ADE20K上的语义分割只符合长序列特性。
另一方面,自回归特性要求每个令牌仅从前面和当前令牌中聚合信息,这被称为令牌混合的因果模式([62])(见图3(a))。实际上,所有视觉识别任务都属于理解领域,而非生成领域,这意味着模型可以一次性看到整个图像。因此,在视觉识别模型中强加额外的因果约束可能会导致性能下降(见图3(b))。提出以下两个假设:(Hypothesis 1) 和 (Hypothesis 2)。
- 假设1:对于图像分类,SSM(自回归序列模型)并非必需,因为这项任务既不符合长序列(long-sequence)特征,也不符合自回归(autoregressive)特征。(SSM: 自回归序列模型)
- 假设2:由于遵循了长序列特性(尽管非自回归),SSM可能对目标检测与实例分割(Object Detection & Instance Segmentation)以及语义分割有益。(SSM:Sequence-to-Sequence Model)
为验证以上两个假设,作者设计了一系列模型, 称为MambaOut, 通过堆叠Gated CNN([18]) 模块。 Gated CNN和Mamba模块的关键区别在于存在自注意力机制 (SSM), 如图1(a)所示。
实际上,更简单的MambaOut模型已经超越了视觉Mamba模型(如[102, 49,37, 86])的性能,这验证了假设1。同时观察到MambaOut在目标检测和分割任务(见表2和3)中的表现不如最先进的视觉Mamba模型,这突显了SSM在这些任务中的潜力,进一步验证了假设2。
Conceptual discussion
探讨Mamba模型适用于哪些任务的特性,以及检验视觉识别任务是否符合这些特性。
What tasks is Mamba suitable for?
Mamba的token混合器是选择性自注意力(selective self-attention, SSM)[26, 25],它定义了四个输入相关的参数(
D
e
l
t
a
Delta
Delta,
A
A
A,
B
B
B,
C
C
C),并通过变换将它们转化为(
A
ˉ
\bar{A}
Aˉ,
B
ˉ
\bar{B}
Bˉ,
C
C
C):
A
ˉ
=
e
x
p
(
Δ
A
)
,
B
ˉ
=
(
Δ
A
)
−
1
(
e
x
p
(
Δ
A
)
−
I
)
⋅
Δ
B
.
(
3
)
\bar{A} = exp(\Delta A), \ \bar{B} = {(\Delta A)}^{-1}(exp(\Delta A) - I) \cdot \Delta B. \ (3)
Aˉ=exp(ΔA), Bˉ=(ΔA)−1(exp(ΔA)−I)⋅ΔB. (3)
然后,序列到序列的SSM转换可以表示为:
h
t
=
A
ˉ
h
t
−
1
+
B
ˉ
x
t
,
(
4
)
h_{t} = \bar{A} h_{t-1} + \bar{B} x_{t}, \ (4)
ht=Aˉht−1+Bˉxt, (4)
y
t
=
C
h
t
.
(
5
)
y_{t} = C h_{t}. \ (5)
yt=Cht. (5)
其中,
t
t
t表示时间步,
x
t
x_{t}
xt代表输入,
h
t
h_{t}
ht表示隐藏状态,
y
t
y_{t}
yt表示输出。方程2的递归性质[34]将 RNN 类似 SSM 与因果注意力区分开来。隐藏状态
h
h
h可以看作是一个固定大小的记忆,存储所有历史信息。通过方程2,这个记忆在保持大小不变的同时进行更新。固定大小意味着记忆不可避免地存在信息丢失,但确保了将记忆与当前输入融合的计算复杂度保持不变。相反,因果注意力将所有先前token的键和值作为其记忆,随着每个新输入,记忆会通过添加当前token的键和值而扩展。理论上,这种记忆是无损的。然而,随着更多token输入,记忆大小增加,与当前输入融合的复杂性也随之上升。RNN 类似模型和因果注意力之间的记忆机制差异在图2中有更直观的展示。
由于SSM 的记忆本质上是丢失的,它在处理短序列方面自然不如注意力的无损记忆。因此,Mamba 在处理短序列时无法充分发挥其优势,而注意力在这个领域表现得游刃有余。然而,在长序列处理中,注意力由于其二次复杂度而面临困难。在这种情况下,Mamba 显著提高了将记忆与当前输入融合的效率,从而平稳地管理长序列。因此,Mamba 特别适合处理长序列。
尽管SSM 的递归性质(方程2)使得Mamba 能够高效处理长序列,但也引入了一个重大局限:
h
t
h_{t}
ht 只能访问前一时间和当前时间的信息。如图 3 所示,这种类型的tokrn混合称为因果模式,可以表述为:
y
t
=
f
(
x
1
,
x
2
,
.
.
.
,
x
t
)
.
y_{t} = f(x_{1}, x_{2}, ..., x_{t}).
yt=f(x1,x2,...,xt).
其中,
x
t
x_{t}
xt 和
y
t
y_{t}
yt 分别表示第
t
t
t个令牌的输入和输出。由于其自回归的特性,这种模式非常适合自回归(Autoregressive)生成任务。
另一种模式称为全可见模式,其中每个token可以聚合所有先前和后续token的信息。这意味着每个token的输出都依赖于所有token的输入:
y
t
=
f
(
x
1
,
x
2
,
.
.
.
,
x
T
)
,
y_{t} = f(x_{1}, x_{2}, ..., x_{T}),
yt=f(x1,x2,...,xT),
其中,
T
T
T表示总tokens数量。全可见模式适用于理解任务,此时模型可以一次性访问所有输入。
注意力默认处于全可见模式,但通过在注意力映射上应用因果掩码,可以轻松将其转换为因果模式。类似于RNN的模型由于其递归特性,自然地以因果模式运行,如Mamba的公式2所示。由于这种固有特性,RNN类模型无法转变为全可见模式。
尽管RNNs可以通过双向分支近似全可见模式,但每个分支仍然保持因果模式。因此,由于其递归性质的固有限制,Mamba特别适合那些需要因果token混合的任务。
总之,Mamba最适合具有以下特性的任务:
- 特征1:任务涉及处理长序列(Long Sequence Processing)
- 特征2:任务需要因果token 混合模式(Causal Token Mixing)
接下来,讨论视觉识别任务是否具备这两个特性。
Do visual recognition tasks have very long sequences?
探讨视觉识别任务是否真的需要长序列建模。以Transformer模型([75])作为案例研究,以便进行分析。假设Transformer块具有常见的MLP (多层感知器)比为4;如果其输入
X
∈
R
L
×
D
X \in R^{L \times D}
X∈RL×D的序列长度为
L
L
L,通道(嵌入)维度为
D
D
D,那么该块的计算量(FLOPs)可以这样计算:
F
L
O
P
s
=
24
D
2
L
+
4
D
L
2
.
(
6
)
FLOPs = 24D^{2}L + 4 D L^{2}.\ (6)
FLOPs=24D2L+4DL2. (6)
由此推导出
L
L
L中二次项与线性项的比例为:
r
L
=
4
D
L
2
24
D
2
L
=
L
6
D
.
(
7
)
r_{L} = \frac{4DL^{2}}{24D^{2}L} = \frac{L}{6D}. \ (7)
rL=24D2L4DL2=6DL. (7)
如果 L > 6 D L > 6D L>6D,二次项的计算负担超过了线性项,这提供了一个简单的指标来判断任务是否涉及长序列。例如,在ViT-S中,当有384个通道时,阈值 τ s m a l l = 6 × 384 = 2304 \tau_{small} = 6 \times 384 = 2304 τsmall=6×384=2304,而在ViT-B中,当有768个通道时, τ b a s e = 6 × 768 = 4608 \tau_{base} = 6 \times 768 = 4608 τbase=6×768=4608。
对于ImageNet上的图像分类,典型输入图像大小为 22 4 2 224^{2} 2242,使用 1 6 2 16^{2} 162的patch大小,这将产生 1 4 2 = 196 14^{2} = 196 142=196个token。显然,196远小于 τ s m a l l \tau_{small} τsmall和 τ b a s e \tau_{base} τbase,表明ImageNet上的图像分类不被视为长序列任务。对于COCO上的对象检测与实例分割,其推理图像大小为 800 × 1280 800 \times 1280 800×1280,而在ADE20K上的语义分割,推理图像大小为 512 × 2048 512 \times 2048 512×2048,使用 1 6 2 16^{2} 162的patch大小,大约会产生4K个token。由于4K大于 τ s m a l l \tau_{small} τsmall且接近 τ b a s e \tau_{base} τbase,因此在COCO上的检测和ADE20K上的分割都可以视为长序列任务。
Do visual recognition tasks need causal token mixing mode?
如前所述并如图3所示,全可见的token混合模式允许无限制的混合范围,而因果模式则限制当前token只能访问前导令牌的信息。视觉识别被归类为理解任务,模型可以一次性看到整个图像,因此无需对token混合施加限制。对token混合施加额外约束可能会降低模型性能。如图3(b)所示,当将因果限制应用于视觉Transformer(ViT,源自[23, 72])时,性能明显下降。
通常,全可见模式适用于理解任务,而因果模式更适合自回归任务。这一观点也得到了BERT ([20])和ViT (BEiT[4]、MAE[30])更多用于理解任务,而GPT-1/2[59, 60]和图像GPT[9]主要用于生成任务的观察所支持。因此,视觉识别任务不需要因果的token混合模式。
Hypotheses regarding the necessity of Mamba for vision
根据之前的讨论,作者总结了关于在视觉识别任务中引入Mamba的假设如下:(Mamba Hypothesis)
- 假设1:在ImageNet图像分类任务中,无需引入自注意力机制(SSM),因为该任务不满足特性1或特性2。
- 假设2:尽管不满足特性1,但视觉检测和分割任务与特性2相契合,因此继续探索SSM (选择性自注意力)在这些任务中的潜力仍然值得 (Selective Self-Attention)。
Gated CNN and MambaOut
旨在通过实证来验证上述假设。如图1(a)所示,Mamba块基于[25] Gated CNN块[18]。Gated CNN和Mamba的元架构可以看作是MetaFormer[89]的toekn混合器和MLP的简化整合,类似于MetaNeXt[91]。形式化地,给定输入
X
∈
R
N
×
D
X \in R^{N \times D}
X∈RN×D,元架构可以表示为:
X
′
=
N
o
r
m
(
X
)
,
(
8
)
X^{'} = Norm(X), \ (8)
X′=Norm(X), (8)
Y
=
(
T
o
k
e
n
M
i
x
e
r
(
X
′
W
1
)
⋅
(
X
′
W
2
)
)
W
3
+
X
,
(
9
)
Y = (TokenMixer(X^{'}W_{1}) \cdot (X^{'} W_{2}))W_{3} + X, \ (9)
Y=(TokenMixer(X′W1)⋅(X′W2))W3+X, (9)
其中,
N
o
r
m
(
⋅
)
Norm(\cdot)
Norm(⋅) 表示归一化操作([38, 2, 82]);
T
o
k
e
n
M
i
x
e
r
(
⋅
)
TokenMixer(\cdot)
TokenMixer(⋅) 指的是执行token混合的模块([90]);
W
1
∈
R
D
×
r
D
W_{1} \in R^{D \times rD}
W1∈RD×rD,
W
2
∈
R
D
×
r
D
W_{2} \in R^{D \times rD}
W2∈RD×rD 和
W
3
∈
R
r
D
×
D
W_{3} \in R^{rD \times D}
W3∈RrD×D 是具有MLP 扩展系数
r
r
r 的可学习参数;
σ
\sigma
σ 是激活函数([24, 33])。Gated CNN 和Mamba 的token混合器为:
T
o
k
e
n
M
i
x
e
r
G
r
a
t
e
d
C
N
N
(
Z
)
=
C
o
n
v
(
Z
)
,
(
10
)
TokenMixer_{GratedCNN} (Z) = Conv(Z) ,\ (10)
TokenMixerGratedCNN(Z)=Conv(Z), (10)
T
o
k
e
n
M
i
x
e
r
M
a
m
b
a
(
Z
)
=
S
S
M
(
σ
(
C
o
n
v
(
Z
)
)
)
.
(
11
)
TokenMixer_{Mamba}(Z) = SSM(\sigma(Conv(Z))) .\ (11)
TokenMixerMamba(Z)=SSM(σ(Conv(Z))). (11)
比较公式10和11,并且参考图1(a),Gated CNN([59])与Mamba块([25])的主要区别在于是否存在SSM。这促使作者开发了一系列基于Gated CNN块但不包含SSM的模型,称之为MambaOut。这将帮助评估Mamba在视觉识别任务中的必要性。
具体来说,将Gated CNN的token mixer指定为
7
×
7
7 \times 7
7×7内核大小的深度卷积,遵循ConvNeXt([51, 54])。为了提高实际速度,仅在部分通道上进行深度卷积,这类似于ShuffleNet([53])和InceptionNeXt ([91])。如算法1所示,Gated CNN块的实现简洁高效。类似于ResNet,采用4阶段框架,每一阶段堆叠Gated CNN块,如图4所示。各模型大小的详细配置见附录中的表4。
4、实验
Image classification on ImageNet
实验设置:ImageNet [19, 65] 是图像分类的黄金标准基准,包含1000个常用类别。它包括大约130万张训练图像和5万张验证图像。我们的训练方案遵循DeiT [72],不使用知识蒸馏。具体来说,我们使用的数据增强包括随机缩放裁剪(输入图像大小为2242)、水平翻转、RandAugment[15]、Mixup[98]、CutMix [94]、随机擦除 [100] 和色彩抖动,以及正则化技术包括权重衰减、深度学习中的随机深度[36] 和标签平滑[70]。所有模型使用AdamW [52,41] 进行训练。学习率规则为
l
r
=
b
a
t
c
h
s
i
z
e
1024
×
1
0
−
3
lr = \frac{batchsize}{1024} \times 10^{−3}
lr=1024batchsize×10−3。本文设置批量大小为4096,因此学习率为0.004。MambaOut模型使用PyTorch [55] 和timm [80] 库实现,并在TPU v3上进行训练。更多的训练超参数在附录表5中展示。
结果:表1显示了MambaOut模型、视觉Mamba模型和其他卷积和注意力模型在ImageNet [19,65] 上的性能。值得注意的是,不包含SSM的MambaOut模型在所有模型规模上都持续优于包含SSM的视觉Mamba模型,如Zhu et al. (2024) [102], Liu et al. (2024) [49], Huang et al. (2024)
[37], Pei et al. (2024) [57] 和Yang et al. (2024) [86]。例如,MambaOut-Small模型的Top-1精度达到84.1%,比LocalVMamba-S [37] 高出0.4%,同时只消耗79%的MACs。这些结果有力地支持了假设1,即在ImageNet的图像分类中引入SSM是不必要的,这符合奥卡姆剃刀的原则。
此外,视觉Mamba模型目前与最先进的卷积和注意力模型存在显著性能差距。例如,CAFormer-M36 [90],它使用传统的Token混合器,如简单的分离卷积 [66] 和标准注意力机制 [75],在精度上比所有同类大小的视觉Mamba模型高出超过1%。如果未来的研究旨在挑战假设1,则需要开发结合卷积和SSM的视觉Mamba模型,以在ImageNet上实现最先进的性能。
Object detection & instance segmentation on COCO
实验设置:COCO 2017 [47] 是广泛认可的物体检测和实例分割基准。在实验中,MambaOut 作为Mask R-CNN [31] 的基础架构,使用在ImageNet上预训练的权重初始化。我们遵循标准的
1
×
1 \times
1×训练计划,共 12 个 epoch。训练图像的大小调整为短边为 800 像素,长边不超过1333像素。使用 AdamW [52,41] 优化器,学习率设为 0.0001,总批次大小为 16。实现基于PyTorch [55] 和mmdetection [8] 库。利用FP16 精度来节省训练成本。实验在4 块NVIDIA 4090 GPU 上进行。
结果:尽管MambaOut在COCO上的目标检测和实例分割方面可以超越一些视觉Mamba模型[57, 86],但仍落后于诸如VMamba[49]和LocalVMamba[49]等最先进的视觉Mamba模型。例如,作为Mask R-CNN的骨干网络,MambaOut-Tiny的性能比VMamba-T[49]低 1.4 A P b 1.4 AP^{b} 1.4APb和 1.1 A P m 1.1 AP^{m} 1.1APm。这种性能差距突显了在长序列视觉任务中集成Mamba的好处,进一步支持假设2。然而,与最先进的卷积-注意力混合模型TransNeXt[68]相比,视觉Mamba仍存在显著的性能差距。视觉Mamba需要通过在视觉检测任务中胜过其他最先进模型来进一步验证其有效性。
Semantic segmentation on ADE20K
实验设置:ADE20K [101](ADE20K),一个广泛用于语义分割任务的基准,包含了150个语义类别。它包含训练集20,000张图像和验证集2,000张图像。在我们的实验中,Mamba被用作UperNet [84]的后处理,从ImageNet预训练权重初始化。使用AdamW优化器 [41, 52]进行训练,学习率设为0.0001,批量大小为16,训练160,000个迭代。实现基于PyTorch[55]和mmsegmentation [14]库。实验在四块NVIDIA 4090 GPU上进行,使用FP16精度以提高训练速度。
结果:ADE20K上的语义分割性能趋势与COCO上的对象检测相似。MambaOut能超过一些视觉Mamba模型,但无法与最先进的Mamba模型相媲美。例如,LocalVMamba-T [37]在单尺度(SS)和多尺度(MS)评估中分别比MambaOut-Tiny高出0.5 mIoU,这进一步证实了假设2。此外,视觉Mamba模型与整合卷积和注意力机制的更高级混合模型(如SG-Former[64]和TransNeXt [68])相比,仍然表现出明显的性能差距。视觉Mamba需要在视觉分割任务中展现出更强的长序列建模能力。
5、总结
论文从概念上探讨了Mamba机制,并认为它非常适合具有长序列和自回归特性的任务。分析了常见的视觉任务,认为在ImageNet图像分类中引入Mamba是不必要的,因为它既不符合长序列特性,也不符合自回归特性。然而,Mamba在视觉检测和分割任务中,由于与长序列特性相符,其潜力值得进一步研究。为了实证这一观点,开发了MambaOut模型,这些模型使用Mamba块但不包含核心的token混合器SSM。MambaOut在ImageNet上的表现超越了所有视觉Mamba模型,但与最先进的视觉Mamba模型相比存在明显性能差距,这验证了文中的论点。