从状态空间SSM到Mamba

目录

比较四种网络架构——CNN、RNN、Transformer、Mamba

状态空间模型SSM

1 - 状态空间可以表示神经网络

2 - 离散的SSM

S4【结构化状态空间序列=SSM+HiPPO+Structured Matrices】

S5【多输入、多输出SSM】

S6【SSM+Selection机制 】

【Mamba Block的代码】

【Mamba的并行扫描操作】

【Mamba的硬件改进】

VMamba【视觉任务中的 Mamba】

2D 选择性扫描(SS2D=S6+CSM)

​编辑

VSS视觉状态空间(VSS Block = 卷积+Silu激活+SS2D)

Sigma【多模态Mamba】


比较四种网络架构——CNN、RNN、Transformer、Mamba

网络结构特点(优势)局限性
CNN
  1. 参数共享:CNN中的卷积核在整个图像中共享参数,这种权重共享机制显著降低了计算复杂度,并简化了训练过程。
  2. 空间层次结构:CNN通过多层卷积和池化层自动学习输入数据的有用特征,使得CNN能够处理不同尺寸的图像。
  3. 强大的特征提取能力:CNN尤其适用于图像识别和分类任务等
  4. 并行计算:CNN中卷积和池化操作可以并行计算,因此在GPU等硬件上具备良好的高效能。

局部感受野:CNN利用局部感受野的思想,通过卷积操作从图像中提取特征,这使得它在处理图像中的局部特征时非常有效,但是不能考虑到全局的特征。

计算复杂性:由于CNN的层数较多,参数量较大,导致模型的计算复杂性较高。

RNN
  1. 捕捉长期依赖关系:RNN具有记忆能力,可以捕捉序列数据中的长期依赖关系,适用于需要考虑上下文信息的任务。
  2. 动态长度输入:RNN可以接受不定长度的输入,这使得它在处理时间序列数据时非常灵活。

只能关注于较短的上下文,对于长时序难以捕获到关系

同时在训练参数的时候,由于参数共享和多次连乘的特性,容易出现梯度消失或梯度爆炸的问题,导致模型难以训练或无法收敛。

另外RNN的计算过程是顺序的,导致在训练和推理阶段计算复杂度较高。同时它对超参数敏感和可解释性较差

Transformer
  1. 自注意力机制:Transformer通过自注意力机制能够同时处理整个输入序列中的所有位置,捕捉长距离依赖关系,从而更准确地理解文本含义。
  2. 并行计算:Transformer能够并行处理整个输入序列,显著提高了训练速度和效率。
  3. 灵活性和通用性:Transformer适用于各种序列到序列的任务,包括文本、图像等,并且具有很强的建模能力和通用性。
  4. 位置关联操作不受限:Transformer的位置关联操作不受限,建模能力强,通用性强,可扩展性强
Transformer模型中自注意力机制的计算量会随着上下文长度的增加呈平方级增长
Mamba
  1. 处理长序列数据:Mamba使用选择性状态空间来处理序列,解决了Transformer在处理长序列时的计算效率问题。注意力层可以随着序列长度而线性缩放,实现了更快的推理速度。
  2. 计算效率:Mamba具有较高的计算效率,特别是在GPU上运行时,数据存取交互更为友好,Mamba的推理速度也显著优于Transformer(5倍)。
  3. 选择机制:Mamba通过集成选择机制来控制记忆范围,从而提高模型的泛化能力,并且提供了在CV领域的应用潜力。
  4. 可扩展性:Mamba架构具有出色的可扩展性,允许模型更容易地扩展到更大的尺寸,而不会牺牲性能。
安装和环境配置问题:在实际应用中,Mamba模型及其相关组件(如causal-conv1d和mamba_ssm)的安装过程中常常遇到各种错误和兼容性问题,特别是在Windows系统上,这些问题包括连接超时、gcc版本不正确以及依赖不完整等。

Mamba核心创新点: ①注意力层可以线性增长;②可以处理长序列的数据;③推理速度可以达到Transformer的5倍

状态空间模型SSM

1 - 状态空间可以表示神经网络

State Space Model - h(t)为某一时刻的状态变量,导数h'(t)、输出y(t)都与当前时刻的状态h(t)和输入x(t)有关,数学表达为:

其中,t为时间,状态空间中将连续序列作为输入并预测输出序列。

即,状态空间SSM的作用是:① 输出伴随输入的变化而变化;② 模型会存储内部状态,并利用4个关键的矩阵A,B,C,D 改变内部状态和输出

参考Mamba:2 状态空间模型 (qq.com),一个简单的理解:

1. 输入序列x(t) —(例如游戏中的操作,直冲还是挥拳)

2. 根据输入映射到潜在状态h(t) —(例如,(操作不当导致)快要挂了的状态)

3. 并预测输出序列y(t) —(例如,要放大招反击)

同样地,大模型中的神经网络其实也有隐藏状态(上下文信息),当预测t时刻的输出(下一个字符Token)的时候也要根据当前t时刻的输入(当前的token),同时结合隐藏状态信息(上下文信息)。因此  状态空间与大模型技术的实现过程是类似的

2 - 离散的SSM

神经网络的输入是离散的向量,而状态空间的输入是随时间不断变化的连续状态h(t),所以需要进行“连续到离散”的处理,得到离散SSM。

zero-order hold (零阶保持,ZOH) 是传统数模转换器(DAC)的实际信号重建的数学模型。它通过将每个样本值保留一个样本间隔Δ来将离散时间信号转换为连续时间信号。在Mamba中被采用的离散化的方法是ZOH。

矩阵 D提供从输入到输出的直接信号。这通常也称为跳跃连接,Mamba的模型是在没有跳跃连接的情况下进行设计的,即Mamba包括4个参数矩阵A、B、C 和 Δ

Mamba中A、B、C 和 Δ 的含义:

A:状态转移矩阵,描述状态之间的转换关系,用于确定有多少隐藏状态应该从一个token前 向传播到另一个token。

B:输入到状态的映射矩阵,用于确定有多少输入进入隐藏状态。

C:状态到输出的映射矩阵,将状态映射到输出,用于确定隐藏状态如何转化为输出。

Δ:状态空间的选择性参数,修改A和B矩阵的权重用于控制模型对输入的关注程度。

在Mamba的论文中,将SSM的连续矩阵参数 (Δ,A,B) 调整为离散化参数矩阵(\overline{A}\overline{B}

在参数由(∆,A, B, C)变换为(\overline{A}, \overline{B},C)后,模型可以用两种方式计算,即线性递归(Recurrent,类似RNN)全局卷积(Convolutional,可并行计算)

图片粘贴自:Mamba 基础讲解【SSM,LSSL,S4,S5,Mamba】_mamba s4-CSDN博客

对于卷积计算:

整体的计算流程中,K表示为卷积;SSM的计算复杂度为O(n) 但是若将Δ,A,B,C应用于所有输入,模型缺少灵活性。故引入“选择性”【体现在下面的S6介绍】

S4【结构化状态空间序列=SSM+HiPPO+Structured Matrices】

HiPPO(High-order Polynomial Projection Operators)起源于函数逼近,就是用一系列更简单的函数来逼近一个复杂的函数。即傅里叶变换,它用简单的不同频率的正弦曲线的线性组合来逼近复杂函数。

在原始的HiPPO论文中,研究人员用了多项式作为基函数。函数逼近理论基于泛函分析,可以将它看成是定义在函数空间(而不是向量空间)上的线性代数版本。

泛函分析中将函数f(t)视为无穷维度的向量,基函数e_i(t)的线性组合:

f(t) = \sum_{i=1}^{\infty }c_ie_i(t)

Structured State Spaces for Sequences —— S4依靠着魔法矩阵A(HiPPO Matrix记住序列的所有历史,因此擅长对长序列建模,核心技术就是在线函数逼近,目的是使用HiPPO处理远程依赖

1. 将输入序列视为时间的连续函数(离散点可以拟合出连续的曲线)

2. 用一组预定义的简单函数的线性组合来近似模拟出输入函数

3. 只存储(线性组合的)系数作为状态

4. 系数可以通过求解线性常微分方程而(ODE)获得,其中方程参数就是魔幻矩阵A(最重要 的参数矩阵,用HiPPO Matrix表示)

5.当将连续时间转换为离散时间,线性ODE就成为了RNN,函数逼近的系数成为 该RNN的隐藏状态

6. 在结构化状态空间模型(S4)中,矩阵A、B和C与输入x无关,因为它们的维数N和D是静态的,不会改变

S5【多输入、多输出SSM】

S5是基于S4层的设计基础,S4层使用许多独立的单输入、单输出SSM,而 S5层使用一个多输 入、多输出SSM。S5利用S4来设计自身的初始化和参数化。S5进行高效且大范围的并行扫描计 算,Mamba模型可以独立地对每个通道(channel)应用状态空间模型(SSM),并且每个维度(dimension)都有自己独立的SSM,这样就能够针对不同的输入通道进行专门的处理。

S6【SSM+Selection机制 】

Selective Scan Space State Sequential Model,与ViT中注意力机制不同,S6将1D向量中的每个元素(如文本序列)与在此之前扫描过的信息进行交 互,从而有效地将二次复杂度降低到线性。

类似于Transformer中的 x → QKV,这里是将x通过线性变 换转化为BCΔ:(Δ的作用类似于遗忘门,起到前面时刻遗忘多 少比例的调节作用)。在Mamba中,状态矩阵B,C,Δ不再是独立的超参数,二是通过线性变换由输入x计算得到,只有A是 data dependent的,A和Δ是在训练过程中可学习的(含有Parameter)。

【Mamba Block的代码】

可以参考:mamba-minimal/model.py at master · johnma2006/mamba-minimal · GitHub   (但是该代码中的状态空间包含跳跃连接的矩阵D)

【Mamba Block示意图】

【多个Mamba Block共同作用得到网络结构Mamba】

【Mamba的并行扫描操作】

上述的选择性保留信息也带来了一些问题:由于这些矩阵(B,C,∆)现在是动态的,它们不能使用卷积表示(固定的核)进行计算。我们只能使用递归表示,而失去了卷积提供的并行化。

解决:并行化扫描

区别于Recurrent(递归加和,相当于RNN),parallel scan 通过如图所示的并行扫描、加和,使得变量可以并行化计算

【Mamba的硬件改进】

摘自:Mamba 基础讲解【SSM,LSSL,S4,S5,Mamba】_mamba s4-CSDN博客

1)核融合 kernel fusion

GPU的一个缺点是它们在小型但高效的SRAM和大型但略低效率的DRAM之间的传输(IO)速度有限。频繁地在SRAM和DRAM之间复制信息成为瓶颈。

与Flash Attention一样,Mamba试图限制从DRAM切换到SRAM的次数,反之亦然。它通过核融合 来实现这一点,核融合允许模型防止写入中间结果,并持续执行计算,直到完成。

2)重计算 (recomputation)
    中间状态不保存,但对于反向传递计算梯度是必要的。相反,作者在反向传递期间重新计算这些中间状态。虽然这看起来效率不高,但与从相对较慢的DRAM读取所有中间状态相比,它的开销要小得多。

VMamba【视觉任务中的 Mamba】

论文地址:https://arxiv.org/abs/2401.10166

VMamba 成功有效降低注意力复杂性的关键概念继承自选择性扫描空间状态顺序模型(Selective Scan Space State Sequential Model,S6),该模型最初设计用于处理自然语言处理(NLP)任 务。与传统的注意力计算方法不同,S6 使得 1-D 数组中的每个元素(例如文本序列)能够通过压缩的 隐藏状态与先前扫描的任何样本进行交互,有效地将二次复杂性降低到线性。

2D 选择性扫描(SS2D=S6+CSM)

然而,由于视觉数据的非因果性质,直接将这种策略应用于patch化和flatterned的图像将不可避免地导致受限的感受野,因为无法估计相对于未扫描的patch的关系。这个问题称为 “方向敏感” 问题,VMamba通过新引入的交叉扫描模块(Cross-Scan Module,CSM)来解决它。CSM 不是以单向模式(列向或 行向)遍历图像特征映射的空间域,而是采用四向扫描策略,即从特征映射的四个角到相对位置(见下图 2)。这种策略确保特征映射中的每个元素从不同方向的所有其他位置集成信息,从而产生全局感受野,而不增加线性计算复杂性。

我们选择 沿行和列展开图像补丁成序列(扫描扩展)(图b),然后沿四个不同方向进行扫描:从左上到右下, 从右下到左上,从右上到左下,从左下到右上。这样,任何像素(如图 2 中的中心像素)都会从不同方 向的所有其他像素中集成信息。然后,我们将每个序列重新整形成单个图像,并将所有序列合并成一个 新的序列,如图 3 所示(扫描合并)。

VSS视觉状态空间(VSS Block = 卷积+Silu激活+SS2D)

输入特征经过一系列线性投影 (Linear)、深度卷积(DWConv)作为原始Mamba,然后使用SS2D选择性扫描模块来模拟特征中的长距离空间信息,并通过残差连接。

将 S6 与 CSM 集成,作为构建视觉状态空间(Visual State Space,VSS)块的核心元素, 构成了 VMamba 的基本构建块。

Sigma【多模态Mamba】

论文地址:https://arxiv.org/abs/2404.04256

代码地址:GitHub - zifuwan/Sigma: Python implementation of Sigma: Siamese Mamba Network for Multi-Modal Semantic Segmentation

Sigma的本质是三个由Mamba模块组成的结构——用于特征提取的双编码器、多模态融合模块、以及一个通道感知的解码器

编码器backbone:级联的Visual State Space(VSS)块与下采样来从各种模态提取多尺度全局信息

融合模块:Cross Mamba Block(CroMB)+ Concat Mamba Block(ConMB)

解码器:Channel-Aware Mamba Decoder(CVSS),其实就是在VSS模块中引入了一个由平均池化和 最大池化组成的通道注意力操作

编码和解码很好理解,模块结构如上图所示,与VMamba相似,就是用VSS Block取代CNN,Transformer等结构进行特征的提取;感觉文章的重点是多模态的融合模块的设计

融合模块结构示意图:

CroMB(Cross Mamba Block)

CroMB模块的目的是增强来自不同模态的特征表示,具体来说,输入特征首先通过线性层和深度卷积层分别处理,然后送入Cross Selective Scan Module,通过交叉乘法机制来实现。

  • Cross Selective Scan Module: 这个模块利用Mamba的选择机制,通过输入数据生成系统矩阵 B、C和Δ【相当于Transformer中的交叉注意力QKV矩阵】,使模型能够根据输入的上下文来调整自己的行为。这里使用线性投影层来生成这些矩阵。
  • 特征交互: 输入特征通过互补模态的C矩阵在Selective Scan操作中使用,这允许SSM根据另一种模态的引导从隐藏状态重建输出。【如上图的Cross SS中标红的C矩阵,C_rgb用在X向量是状态空间方程中】

ConMB(Concat Mamba Block)

ConMB模块的作用是在CroMB的基础上进一步整合特征,以获得一个融合了两种模态重要信息的特征表示。与CroMB不同,ConMB直接处理连接的特征作为输入,这样可以尽可能保留来自两种模态的信息。

  • 特征连接与处理: CroMB的输出首先通过线性和卷积层处理,然后送入Concat Selective Scan Module。在这个模块中,两个特征首先被展平并连接成一个新的序列,然后对这个序列进行逆扫描(图片中“红”“黑”向量条的交换),以捕获两种模态的长距离依赖关系(通道数*2)。
  • Concat Selective Scan Module: 在这个模块中,连接后的特征序列被进一步处理,以捕获两种模态之间的复杂交互。
  • 输出融合: 经过扫描的特征被乘以从CroMB输出中导出的两个缩放参数,并在通道维度上连接,形成一个形状为 [高度, 宽度, 2*通道数] 的特征。最后,使用线性投影层将特征形状降低到 [高度, 宽度, 通道数] 。

  • 25
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

锅小白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值