【模型架构】学习最火热的Mamba、Vision Mamba、MambaOut模型

一、Mamba

论文链接:Mamba: Linear-Time Sequence Modeling with Selective State Spaces

代码链接:https://github.com/state-spaces/mamba

作者:Albert Gu,Tri Dao

发表单位:卡内基梅隆大学、普林斯顿大学

会议/期刊:暂无

Mamba的提出起源于RNN和Transformer本身存在的问题。

RNN的训练过程中当前时间步依赖于前一时间步的计算,因此不能并行计算,效率非常低,而结构并不复杂,所以推理速度还可以(线性计算);Transformer训练过程是矩阵运算,其训练是可以并行计算的,效率比较高,但是推理过程是一个词一个词去进行矩阵运算(即已经生成了一些token,当生成下一个token时,仍然需要重新计算整个序列的注意力),效率比较低。

那么,能不能提出一个训练和推理过程效率都很高的模型呢?这就有了Mamba。

Mamba是SSM(Structured State Space for Sequence Modeling,序列的结构化状态空间,因为有4个S,所以也称为S4)的改进,所以首先要介绍一下到底什么是SSM?

1.1 SSM的介绍

状态空间模型(State Space Model, SSM)是一种用于描述动态系统的数学模型,特别适用于时间序列分析和控制系统设计。它将系统的状态表示为一个状态向量,并通过状态方程和观测方程描述系统的动态行为和观测过程。

因此,SSM是可以用于描述这些状态表示并根据某些输入预测其下一个状态可能是什么的模型,这就符合了作为深度学习模型基础架构的条件。

SSM的计算示意图

具体来说,可以用下面的公式描述上述过程:

状态变量:描述系统当前状态的变量。状态变量通常是一个向量,包含系统当前时刻的所有信息。

状态方程:描述系统状态如何随时间变化,t+1时刻的状态变化通常形式为:

\mathbf{x}_{t+1}=\mathbf{A}\mathbf{x}_t+\mathbf{B}\mathbf{u}_t+\mathbf{w}_t

其中,xt 是时刻 t 的状态向量,A 是状态转移矩阵,B 是控制输入矩阵,ut​ 是控制输入,wt​ 是过程噪声。

观测方程:描述如何从状态变量获得观测值。通常形式为:

\mathbf{y}_t=\mathbf{C}\mathbf{x}_t+\mathbf{D}\mathbf{u}_t+\mathbf{v}_t

其中,yt 是时刻 t 的观测向量,C 是观测矩阵,D 是直接传输矩阵,vt 是观测噪声。

同样,如果简化噪声,状态方程可以表示系统的状态如何随着时间的推移和输入的变化而变化。

\mathbf h'(t)=\mathbf A\mathbf h(t)+\mathbf B\mathbf x(t)

  • h(t) 是状态向量,表示系统在时间 t 的状态。

  • A 是状态转移矩阵,描述了系统的动态特性。

  • B 是控制输入矩阵,描述了输入 x(t) 如何影响状态。

  • x(t) 是控制输入向量,表示在时间 t 的外部输入。

  • h′(t) 是状态向量对时间 t 的导数,表示状态的变化率。

\mathbf{y}(t)=\mathbf{Ch}(t)+\mathbf{Dx}(t)

  • y(t) 是输出向量,表示在时间 t 的系统输出。

  • C 是输出矩阵,描述了状态向量 h(t) 如何映射到输出 y(t)。

  • D 是直接传输矩阵,描述了输入 x(t) 直接对输出 y(t) 的影响。

这个方程表示如何从状态向量和输入向量计算系统的输出。

状态方程和输出方程

可以注意到,输入和输出此时都是连续的,但是实际应用到深度学习模型当中,需要进行离散化,比如NLP中的单词token输入,CV中的像素块输入。

这里涉及到一部分数学推理,不过有大学高数知识就可以解决。

离散化的常用方法是通过离散化时间步长 Δt 将连续系统转换为等间隔的离散系统。

零阶保持技术(Zero-order hold technique)

首先,决定系统的离散时间间隔,即每次采样的时间间隔 Δt。

使用零阶保持器(Zero-Order Hold, ZOH,假设在每个采样间隔内输入信号保持不变)法进行离散化:

\mathbf{h}[k+1]=\mathbf{A}_d\mathbf{h}[k]+\mathbf{B}_d\mathbf{x}[k]

现在要求解Ad和Bd,让我们回顾一下简单的线性常微分方程:

\mathbf{h}'(t)=\mathbf{Ah}(t)

这是一个线性齐次微分方程。我们期望找到一个解 h(t),它表示系统状态随时间 t 的变化。

对于一阶标量线性常微分方程:

h'(t)=\lambda h(t)

解的形式是:

h(t)=h(0)e^{\lambda t}

求解过程如下所示:

大学高数的知识

这个解表示,状态 h(t)在时间 t 处的值是初始状态 h(0) 乘以指数增长(或衰减)因子e^{\lambda t}

对于向量和矩阵的情形,我们希望找到一个类似的形式。那么,设A是一个矩阵。

\mathbf{h}(t)=e^{\mathbf{A}t}\mathbf{h}(0)

因此得到Ad的解:

\mathbf{A}_d=e^{\mathbf{A}\Delta t}

为了推导 Bd 需要考虑在一个离散时间步长内,输入 x(t) 对状态 h(t) 的影响。

\mathbf{h}^{\prime}(t)=\mathbf{A}\mathbf{h}(t)+\mathbf{B}\mathbf{x}(t)

求解得到:

\mathbf{h}(t)=e^{\mathbf{A}t}\mathbf{h}(0)+\int_0^te^{\mathbf{A}(t-\tau)}\mathbf{B}\mathbf{x}(\tau)d\tau

右边的第一项很好理解,第二项B那边是设置了一个特解,状态转移矩阵 eA(t−τ) 描述了系统从时间 τ 到时间 t 的自由响应。积分 ∫0t​ 累积了所有过去时刻的输入对当前状态的影响。

\mathbf{h}[k+1]=\mathbf{A}_d\mathbf{h}[k]+\mathbf{B}_d\mathbf{x}[k]

先考虑系统状态从 t=kΔt 到 t=(k+1)Δt 的变化。可以表示为:

\mathbf{h}((k+1)\Delta t)=e^{\mathbf{A}\Delta t}\mathbf{h}(k\Delta t)+\int_{k\Delta t}^{(k+1)\Delta t}e^{\mathbf{A}((k+1)\Delta t-\tau)}\mathbf{B}\mathbf{x}(\tau)d\tau

在零阶保持假设下,输入 x(t) 在时间间隔 [kΔt,(k+1)Δt] 内保持不变,即 𝑥(𝜏)=𝑥[𝑘] 。这样,积分变为:

\int_{k\Delta t}^{(k+1)\Delta t}e^{\mathbf{A}((k+1)\Delta t-\tau)}\mathbf{B}\mathbf{x}[k]d\tau

为了简化积分,进行变量替换:设 τ=kΔt+s,其中 s 的范围是从 0 到 Δt。积分变为:

\int_0^{\Delta t}e^{\mathbf{A}(\Delta t-s)}\mathbf{Bx}[k]ds

注意到 x[k] 是常数,可以提到积分符号外:

\left(\int_0^{\Delta t}e^{\mathbf{A}s}ds\right)\mathbf{B}\mathbf{x}[k]

于是离散时间输入矩阵 Bd 的定义为:

\mathbf{B}_d=\left(\int_0^{\Delta t}e^{\mathbf{A}s}ds\right)\mathbf{B}=\mathbf{A}^{-1}\left(e^{\mathbf{A}\Delta t}-\mathbf{I}\right)\mathbf{B}

具体求解过程如下:

配凑的思想,前提是可逆

从而,就实现了SSM的离散化,这样就可以用在NLP的序列到序列预测了。

最终求解答案

对应到网络结构图中:

SSM的网络结构图

可以看到,这里的D其实本质就是跳跃连接,后续讨论的时候我们可以先忽略。

简化的SSM

连续到离散的SSM

Mamba中的状态空间模型计算效率最大的提升就在于其序列运算可以卷积化,从0时刻开始向后推导状态空间模型几个时刻后的输出,可以得到如下的形式:

SSM递推

此形式可以由卷积运算得到,设计适当的卷积核即可将序列运算转化为卷积运算,形式如下所示:

\begin{aligned}\overline{\mathbf{K}}&=\left(C\bar{B},C\overline{AB},\ldots,C\bar{A}^k\bar{B},\ldots\right)\\&y=x*\overline{\mathbf{K}}\end{aligned}

可以把K理解为卷积核,然后里面的值就是卷积核上面的模板。

1.2 SSM的长效依赖——HiPPO矩阵

为了解决RNN中存在的不能对历史信息保持长效依赖的问题,作者对矩阵A用了特殊的初始化技巧(HiPPO矩阵),将迄今为止看到的所有输入信号压缩为系数向量。矩阵 A 构建一个状态表示,可以很好地捕获最近的token并衰减旧的token。其公式可以表示如下:

A_{nk}=-\left\{\begin{array}{ll}(2n+1)^{1/2}(2k+1)^{1/2}&n>k\\n+1\quad n=k\\0\quad n<k\end{array}\right.

1.3 SSM的进化——选择性扫描

然而,SSM在过去一直没有被重视,Mamba在其基础上进行了优化。

传统状态空间模型的时序结构导致了其输出状态完全依赖有序的输入数据。一旦输入数据增减,或者顺序有所变化,那么状态空间模型就无法进行处理。

比如Copy任务

Copy任务是SSM擅长的,因为它可以卷积化,卷积的权值共享性质,导致输出肯定不会有变化。

对于增减过、顺序打乱过的输入,在一些不相关数据混杂在序列中出现时,状态空间模型就无法对其进行有效处理。

选择性任务

Mamba针对这一情况进行了改进,在对B C矩阵进行计算时,加入了选择性机制,即在计算是引入一个额外的线性层,对输入的输入的控制量和状态量进行选择,加强模型对不同输入形式的适应能力,算法流程如下图所示。

Mamba的算法改进

在S6中,B、C、Δ 参数都是输入相关的(都是x经过线性层得到),随着时间变化。这使得模型可以根据输入序列的变化动态调整,从而更好地适应复杂的时间序列数据。

Mamba模型在推理时,可根据不同的输入数据(x)动态计算矩阵B、C和步长Δ(映射与随机参数共同决定Δ)的值,但用于这些计算的参数(即决定如何计算这些矩阵和步长的函数或映射)是固定不变的。这些参数在训练阶段确定,并在推理阶段被重用。

S4和S6的对比

1.4 SSM的进化——并行扫描(parallel scan)

这个主要是计算上的优化,因为现在Mamba放弃了使用卷积并行计算,于是自己设计了一种可以并行运算的方法,讲解来源于 【汇报】 Mamba模型及其公式推导_哔哩哔哩_bilibili

计算流程

示意图

1.5 Mamba的架构

Mamba的架构图

线性投影:

  • 提升输入嵌入的维度,捕获更细致、更复杂的特征。

  • 将原始输入数据映射到新的特征空间,使后续处理更有效。

卷积操作:

  • 提取局部特征,识别序列中的局部模式和结构。

  • 与SSM的长期依赖捕捉能力互为补充,增强模型整体性能。

  • 保留和利用上下文信息,防止独立的token计算,确保上下文信息在处理过程中的传递。

二、Vision Mamba

论文链接:Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model

代码链接:https://github.com/hustvl/Vim

作者:Lianghui Zhu, Bencheng Liao, Qian Zhang, Xinlong Wang, Wenyu Liu, Xinggang Wang

发表单位:华中科技大学、地平线机器人、北京人工智能研究院

会议/期刊:ICML 2024

2.1 整体流程

ViM的整体框架

可以看到,ViM基本是拿Mamba当Transformer用的,所以可以简单介绍一下流程:

首先,将图像 \mathbf{t}\in\mathbb{R}^{\mathrm{H}\times\mathrm{W}\times\mathrm{C}} 切块,转换为patch,也就是 \mathbf{x_p}\in\mathbb{R}^{\mathrm{J}\times(\mathrm{P}^2\cdot\mathrm{C})},其中(H,W)是输入图像的大小,C是通道数,P是图像块的大小,J是patch的数量。

接下来,将xp线性投影到大小为d的向量,并添加位置嵌入 \mathbf{E}_{pos}\in\mathbb{R}^{(J+1)\times D}

\mathbf{T}_{0} =[\mathbf{t}_{cls};\mathbf{t}_{p}^{1}\mathbf{W};\mathbf{t}_{p}^{2}\mathbf{W};\cdots;\mathbf{t}_{p}^{\mathrm{J}}\mathbf{W}]+\mathbf{E}_{pos},

其中t jp是t的第j个patch, \mathbf{w}\in\mathbb{R}^{(\mathbb{P}^2\cdot\mathbb{C})\times\mathbb{D}} 是可学习的投影矩阵。同时,和ViT一样的是,采用t cls表示整个patch序列,用于存储最后的预测分类结果,

然后,patch序列T i-1发送给ViM编码器的第i层,获得输出Tl。最后,要对输出类tokenT 0L进行归一化,将其输入到MLP中得到最终的预测结果。

\begin{aligned}&\mathbf{T}_{l}=\mathbf{Vim}(\mathbf{T}_{1-1})+\mathbf{T}_{1-1},\\&\mathbf{f}=\mathbf{Norm}(\mathbf{T}_{\mathrm{L}}^{0}),\\&\hat{p}=\mathbf{MLP}(\mathbf{f}),\end{aligned}

原始的Mamba为一维序列而设计,其不适合需要空间感知理解的视觉任务,因此vim block结合了视觉任务的双向序列建模。

ViM Block的处理流程

序列中的每个元素都会被同时在两个方向(前向和后向)上进行处理,每个方向都可能有不同的状态空间参数。通过这种方式,模型能够同时考虑到来自序列前端和后端的信息,从而更全面地捕捉和利用图像中的空间和上下文信息。

2.2 双向SSM处理流程

步骤1: 输入序列标准化

在处理任何序列之前,首先通过一个归一化层(例如Layer Normalization)对输入序列进行标准化处理。这有助于稳定训练过程并改善模型性能。

步骤2: 线性投影

标准化后的序列被线性投影到两个不同的空间x和z,分别用于之后的双向处理和门控机制。这一步骤是通过两个不同的线性层实现的,分别对应算法中的Linearx和Linearz。

步骤3: 双向处理

处理分为两个方向——正向(forward)和反向(backward),有点像双向LSTM。对于每个方向:

  • 卷积处理:先对序列x应用一维卷积(Conv1d),这有助于捕捉局部依赖关系。得到的结果为x'。

  • 线性变换:然后将卷积后的结果x'进一步通过三个线性层变换得到状态空间模型的三个关键参数:B, C, Δ。这里的Δ通过softplus操作确保其为正值,这是因为它将用于计算时间尺度转换。

步骤4: 状态空间转换和处理

使用转换后的Δ来调整状态空间模型的演化矩阵A和输入矩阵B。实际计算中,Δ作为一个缩放因子(通过与一个预设的参数矩阵ParameterA相乘)来调整这些矩阵。完成这一变换后,应用状态空间模型计算最终的输出y。

步骤5: 门控和输出合并

正向和反向的输出通过一个门控机制与z空间相乘(使用SiLU函数作为激活函数),然后将两个方向的结果相加,得到最终的序列输出。这一步骤通过线性层LinearT和残差连接完成最终的输出序列。

三、MambaOut

论文链接:MambaOut: Do We Really Need Mamba for Vision?

代码链接:https://github.com/yuweihao/MambaOut

作者:Weihao Yu, Xinchao Wang

发表单位:新加坡国立大学

会议/期刊:暂无

3.1 论文的假设

  • 假设1:图像分类不需要SSM,因为该任务既不符合长序列也不符合自回归特征。

  • 假设2:SSM 可能对对象检测、实例分割和语义分割有潜在的好处,因为它们遵循长序列特征,尽管它们不是自回归的。

3.2 针对假设的讨论

Attention和RNN的对比

3.2.1 长序列任务的适用性

如上图所示,因果注意力机制会将所有先前令牌的键(k)和值(v)存储为记忆。每当新的令牌输入时,它的键和值就会被添加到记忆中。这种方式的记忆是无损的,意味着没有信息会在转移过程中丢失,因此可以非常精确地保存和利用历史信息。然而,这种记忆方式的缺点在于,随着序列长度的增加,整合旧记忆和当前令牌的计算复杂性也会增加。这使得因果注意力机制能够有效管理短序列,但在处理较长序列时可能会遇到困难。

与因果注意力机制不同,RNN类模型通过将所有先前的令牌压缩到一个固定大小的隐藏状态h中来处理记忆,这个隐藏状态h充当了记忆的角色。由于这种记忆的大小是固定的,它是有损的,即它不能保存每一个细节,因为新的输入必须不断地更新这个固定大小的存储空间。这种有损记忆的缺点是它无法与无损记忆的注意力模型直接竞争保持所有历史信息的能力。

然而,RNN类模型在处理长序列方面展现了独特的优势。由于隐藏状态的大小固定,与当前输入合并旧记忆的复杂度保持不变,不会因为序列长度的增加而增加。这一特点使得RNN类模型在处理长序列任务时,尤其是在有限计算资源的环境下,比注意力机制更具有优势。

SSM是类RNN的。因此,SSM的记忆本质上是有损的,所以从逻辑上来说它达不到注意力的无损记忆。因此,Mamba无法展示其在处理短序列方面的优势,而在这个领域,注意力很容易表现良好。然而,在涉及长序列的场景中,注意力会因其二次复杂度而动摇。在这种情况下,Mamba 可以明显突出其将内存与当前输入合并的效率,从而顺利地管理长序列。因此,Mamba 特别适合处理长序列。

两种Token混合模式

尽管 SSM的循环性质允许 Mamba 有效地处理长序列,但它引入了一个重大限制:ht(隐藏状态)只能访问来自先前和当前时间步的信息(1~t个时间步,如图a的右边)。如上图所示,这种类型的令牌混合称为因果模式,可以表述为:

y_{t}=f(x_{1},x_{2},...,x_{t}),

其中xt和yt分别表示第t个token的输入和输出。由于其因果性质,该模式非常适合自回归生成任务。

另一种模式称为Fully-visible模式,其中每个token可以聚合所有先前和后续token的信息(如图a的左边)。这意味着每个token的输出取决于所有token的输入:

y_{t}=f(x_{1},x_{2},...,x_{t},...,x_{T}),

其中T表示令牌总数。Fully-visible模式适合理解任务,模型可以立即访问所有输入。

默认情况下,注意力处于Fully-visible模式,但通过将因果掩模应用于注意力图,它可以轻松转变为因果模式。

\begin{aligned}&h_{t}=\mathbf{A}h_{t-1}+\mathbf{B}x_{t} ,\\&y_{t}=\mathbf{C}h_{t},\end{aligned}

由于其循环特性,类 RNN 模型本质上以因果模式运行,如 Mamba 方程 上面所示。由于这种固有特性,类 RNN 模型无法转换为Fully-visible模式。尽管 RNN 可以使用双向分支来近似完全可见的模式,但每个分支仍然单独保持因果模式。因此,由于其循环属性的固有限制,Mamba 非常适合需要因果令牌混合的任务。

总之,Mamba 非常适合具有以下特征的任务:

  • 任务涉及处理长序列;

  • 任务需要因果标记混合模式。

3.3 视觉任务是否符合上面的2个结论?

3.3.1 是否是长序列?

首先,定义了长序列任务的概念,即处理的数据序列长度超出常规处理范围,给计算带来显著挑战。长序列的判定依赖于序列长度与模型处理能力之间的关系。

通过考虑一个典型的Transformer块,其多层感知机(MLP)比率为4,分析了计算复杂度。

假设输入的维度是 L×D(L 是令牌长度,D 是通道或嵌入维数),该块的浮点运算次数(FLOPs,在之前的一篇博客内容中有提到ViT的计算)为:

\mathrm{FLOPs}=24D^{2}L+4DL^{2}.

从中提取出关于 L 的二次项和线性项的比率:

r_L=\frac{4DL^2}{24D^2L}=\frac L{6D}

这个比率用于判断处理的序列是否足够长,即当 L > 6D 时,认为是长序列。为什么这么定义呢?因为当 L > 6D 时,其处理复杂度主要受二次项的影响,而这通常需要特别的模型架构或优化方法来有效处理。

具体任务分析:

  • 图像分类在ImageNet上的应用:对于224x224的输入图像,如果将其切分为16x16的patch,则每张图像分解成196个patch。这个数字远小于长序列的阈值(例如,对于ViT-S模型,阈值为2304),表明ImageNet上的图像分类任务不是一个长序列任务。

  • 目标检测和实例分割在COCO上的应用:对于800x1280的输入图像,切分为16x16的patch,产生的令牌数量约为4K,这超过了小型模型的阈值2304,接近或等于较大模型的阈值4608,因此可以认为是长序列任务。

  • 语义分割在ADE20K上的应用:对于512x2048的输入图像,同样切分为16x16的patch,产生的令牌数量也约为4K,同样满足长序列任务的条件。

3.3.2 是否需要因果标记混合模式?

视觉识别被归类为一种理解任务。在这类任务中,模型通常需要对整个图像进行分析,以便捕捉和理解图像中的各种视觉元素和它们之间的关系。因此,这些任务不需要因果模式的限制,因为添加这种限制可能会降低模型的性能。例如,当Vision Transformers(ViT)被限制只能使用因果模式时,会观察到其性能明显下降。

BERT和ViT这类模型主要用于理解任务,通常采用Fully-visible模式,而GPT系列和Image GPT这类模型则更多地用于生成任务,通常采用因果模式。这表明视觉识别任务(如图像分类、目标检测等),由于其本质上是理解全局内容的需求,因此更适合使用Fully-visible模式,而不是因果模式。

3.3.3 进一步的假设
  • 假设1:在ImageNet上没有必要引入SSM进行图像分类,因为该任务不满足第1点或特第2点。

  • 假设2:尽管不满足第2点,但仍然值得进一步探索SSM在检测和分割方面的潜力,因为这些任务符合第1点。

3.4 实验验证

MambaOut的网络结构(input的demo是余华老师的小狗头像哈哈)

门控卷积的设计代码

分辨率为 224x224 时模型在 ImageNet 上的性能

显然,ViM模型的性能还达不到 MambaOut 的水平,更不用说超越最先进的卷积或卷积注意力混合模型了。

使用 Mask R-CNN 在 COCO 上进行对象检测和实例分割的性能。 MAC 是通过输入大小 800×1280 来测量的。

UperNet在 ADE20K验证集上的语义分割性能。 MAC 是通过输入大小 512×2048 来测量的。

ADE20K 上语义分割的性能趋势与 COCO 上的对象检测类似。 MambaOut 可以胜过一些视觉 Mamba 模型,但无法与最先进的 Mamba 模型的结果相匹配。

3.5 总结

本文提出结论如下:

  • Mamba适合长序列、自回归特征的任务;

  • Mamba不适合分类任务,但是在检测和分割上有探索的潜力。

  • 12
    点赞
  • 54
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值