一文了解Mamba和选择性状态空间模型 (SSM)

在这里插入图片描述

前言

在这篇博文中,我们将带您深入了解序列建模的演变历程,从最初的简单前馈神经网络,到 Transformer 的出现,这一架构的革新彻底改变了自然语言处理等领域的面貌。接着,我们将探讨该领域的最新进展:Mamba架构,这一新兴模型有望突破 Transformer 的一些局限性,提升处理长序列数据的效率与准确性。

如今,基于 Transformer 架构的模型已经成为深度学习领域的核心技术,支持着各种令人兴奋的应用。Transformer 的核心——自注意力机制——在多个任务中表现出色。然而,尽管 Transformer 在捕捉全局信息方面极具优势,但其计算复杂度呈二次增长,这使得它在处理高分辨率图像或密集预测任务时,尤其是在长序列输入的场景下,面临巨大的计算负担。此外,Transformer 还存在对数据量需求较高的问题,在数据量不足的情况下往往表现不佳,同时全局特性使得它对细节的敏感度较差,这在需要精细边界预测的任务中尤为明显。

为了理解transformer的局限性,下面结合图稍微解释一下

如下是一个transformer的解码器结构,只使用解码器来创建生成模型
在这里插入图片描述
在生成下一个 Token 时,我们需要重新计算整个序列的 attention ,即使我们已经生成了一些 token。

在这里插入图片描述
假设输入序列长度为 L L L,则每个 token 需要计算 L L L次注意力分数,总的注意力计算次数为 L × L = L 2 L \times L=L^2 L×L=L2 ,即计算量呈现为 O ( L 2 ) O(L^2) O(L2) 的平方复杂度。如果序列长度增加,这可能会很昂贵。

因此SSM(结构化状态空间模型)的出现,通过解决Transformer在长序列处理中的计算效率问题,得到了广泛应用。与Transformer相比,SSM具有线性复杂性,能够根据输入有选择地传播或遗忘信息。Mamba模型结合了选择性SSM,提升了推理速度,并在多个领域(如语言和基因组学)取得了优异性能,超越了同规模的Transformer模型。

与transformer的比较

特征transformerMamba
架构基于注意力机制基于 SSM
复杂度
推理速度O(n)O(1)
训练速度O(n²)O(n)

SSM

SSM简介

状态空间模型(State Space Models,简称SSM)在控制理论中传统用于通过状态变量对动态系统建模。

Aaron R. VOELKER和Chris ELIASMITH提出了一个重要问题:大脑如何有效地表示时间信息。在他们2018年发表的论文《Improving Spiking Dynamical Networks: Accurate Delays, Higher-Order Synapses, and Time Cells》中,他们发现SSM能够很好地描述大脑中存在的“时间细胞”(尤其是海马体和皮层)。他们随后将这一发现应用于深度学习,成为最早在此领域中使用SSM的研究者之一。关于这一研究的更多细节。

下文将定义深度学习中SSM的基础知识

深度学习中的SSM定义

状态空间模型(State Space Model, SSM)是一种用于描述动态系统状态随时间演变的数学模型。SSM通过一组矩阵和状态变量来描述系统如何随时间步推进。该模型通常包含状态方程和输出方程,能够在连续时间或离散时间下进行计算。
如下图所示,我们可以定义SSM的结构:

在这里插入图片描述
在SSM中有三个与时间 t t t相关的变量:

  • x ( t ) ∈ C n x(t) \in \mathbb{C}^n x(t)Cn:表示 n n n个状态变量,反映当前系统的状态。
  • u ( t ) ∈ C m u(t) \in \mathbb{C}^m u(t)Cm:表示 m m m个输入变量,作为外部影响引入系统。
  • y ( t ) ∈ C p y(t) \in \mathbb{C}^p y(t)Cp:表示 p p p个输出变量,通常是我们希望从系统中观测到的值。

此外,SSM由四个可学习的矩阵构成,矩阵的大小对应于变量维度:

  • A ∈ C n × n A \in \mathbb{C}^{n \times n} ACn×n:状态矩阵,控制状态向量的演变,影响状态随时间的更新。
  • B ∈ C n × m B \in \mathbb{C}^{n \times m} BCn×m:控制矩阵,用于将输入向量 u ( t ) u(t) u(t)作用到状态 x ( t ) x(t) x(t)
  • C ∈ C p × n C \in \mathbb{C}^{p \times n} CCp×n:输出矩阵,控制状态对输出的影响。
  • D ∈ C p × m D \in \mathbb{C}^{p \times m} DCp×m:指令矩阵,直接将输入 u ( t ) u(t) u(t)对应到输出 y ( t ) y(t) y(t)

这些变量满足以下系统流程:

  1. 状态更新:描述系统状态的动态演变
    x ′ ( t ) = A x ( t ) + B u ( t ) x′(t) = Ax(t) + Bu(t) x(t)=Ax(t)+Bu(t)
  2. 输出方程:描述如何通过状态得到系统的输出
    y ( t ) = C x ( t ) + D u ( t ) y(t) = Cx(t) + Du(t) y(t)=Cx(t)+Du(t)
  3. 状态推进:将 x ′ x′ x作为下一时刻的状态,在下一个时间步中继续迭代,从而动态模拟系统的时间演变。

在深度学习的SSM中,通常设定 D u = 0 Du = 0 Du=0(一个易于计算的跳跃连接),此时方程简化为:
x ′ = A x + B u x′ = Ax + Bu x=Ax+Bu
y = C x y = Cx y=Cx

由于该模型是连续时间下的描述,为在计算机上实现,必须先进行离散化处理,使得模型能够在离散时间步下更新状态。

离散化

离散化是状态空间模型(SSM)的核心步骤之一,它将系统的连续时间表示转化为离散时间表示,使得 SSM 可以在计算机上实现。通过离散化,SSM 的连续微分方程可以转换为递归关系或卷积形式,从而更高效地进行计算

将系统的连续时间表示转化为离散时间表示,我们可以通过一个简单的例子来说明这个过程。

一阶连续时间系统

假设有一个简单的连续时间状态空间模型,描述一个物理系统的状态更新。它的状态方程如下:
d x ( t ) d t = − a ⋅ x ( t ) + b ⋅ u ( t ) \frac{dx(t)}{dt}=-a ·x(t)+b·u(t) dtdx(t)=ax(t)+bu(t)
其中:

  • x(t)是系统的状态(例如位置、速度等)
  • u(t)是输入(例如控制信号)。
  • a和b是常数,分别表示系统的衰减因子和输入影响。

离散化的步骤
为了将其转化为离散时间系统,我们需要用一个时间步长 Δ t Δt Δt 来近似连续的状态变化。可以通过以下公式近似得到离散化后的状态更新公式:
x ( t + Δ t ) ≈ x ( t ) + d x ( t ) d t ⋅ Δ t x(t + Δt) \approx x(t) + \frac{dx(t)}{dt} · Δt x(t+Δt)x(t)+dtdx(t)Δt
将原始的微分方程代入:
x ( t + Δ t ) ≈ x ( t ) + ( − a ⋅ x ( t ) + b ⋅ u ( t ) ) ⋅ Δ t x(t + Δt) \approx x(t) + (-a · x(t) + b · u(t)) · Δt x(t+Δt)x(t)+(ax(t)+bu(t))Δt
整理一下,得到离散时间的状态更新公式:
x ( t + Δ t ) = ( 1 − a ⋅ Δ t ) ⋅ x ( t ) + b ⋅ Δ t ⋅ u ( t ) x(t +Δt) = (1-a·Δt)·x(t) +b·Δt·u(t) x(t+Δt)=(1aΔt)x(t)+bΔtu(t)
这个公式就是离散时间下的状态更新方程,它在每个时间步长 Δ t Δt Δt 后更新系统的状态 x ( t ) x(t) x(t)。可以看到,系统的状态不再是连续变化的,而是在每个固定的时间步长 Δ t Δt Δt 内更新一次。

有了离散化的过程,计算机就能够在每个固定的时间步长上更新系统的状态,而不再依赖于连续时间的微分运算。

在 SSM 中,离散化的过程一般有两种途径:

  1. 递归表示:将状态更新公式转化为离散时间的递归关系,使得系统可以逐步更新状态向量。这样,每个时间步都依赖于前一时间步的状态,使得模型能够有效地表示时间序列中的动态演变。这种递归的结构非常适合顺序数据处理,如自然语言处理或时序预测等任务。
  2. 卷积表示:通过卷积形式来描述系统的动态行为。卷积表示允许模型高效地捕捉输入序列和系统状态之间的关系,特别是对于长期依赖的时间序列建模非常有利。这种表示形式能够利用并行化的优势,大幅度加快计算速度,特别是在处理大规模数据时更为高效。

如下图所示,SSM 的结构可以从连续时间、递归、以及卷积这三种视角进行理解:
在这里插入图片描述

  • 连续时间视角:描述了状态向量 x ( t ) x(t) x(t)和输入 u ( t ) u(t) u(t)的连续变化过程。
  • 递归表示:将连续变化转化为时间步长间的递推公式,每一步的状态更新只依赖于前一时刻的状态。
  • 卷积表示:通过卷积操作来描述系统的输入和输出关系,使得模型在时间维度上的计算可以高效并行化。

通过离散化过程,SSM 在连续时间和离散时间之间架起了桥梁,使其能够在时间序列建模中既保持动态演化的精度,又能提升计算效率。

SSM的递归方法

为了实现连续时间系统的离散化,我们可以采用梯形法(Trapezoid method)。其核心思想是通过将表示一个函数 f f f在区间 [ t n , t n + 1 ] [t_n,t_{n+1}] [tn,tn+1]下的曲线区域近似为梯形来计算其面积 T T T。这种近似方法为离散化提供了一个高效的框架

梯形法则

也叫一阶牛顿-柯特斯闭型积分公式称为梯形法则(trapezoidal rule),即在式(1)中使用一次多项式

I = ∫ a b f ( x ) d x ≅ ∫ a b f 1 ( x ) d x I = \int_a^bf(x)dx \cong \int_a^bf_1(x)dx I=abf(x)dxabf1(x)dx
上面公式用直线可以表示为
f 1 ( x ) = f ( a ) + f ( b ) − f a ) b − a ( x − a ) f_1(x) = f(a)+\frac{f(b)-f{a)}}{b-a}(x-a) f1(x)=f(a)+baf(b)fa)(xa)
用这条直线下的面积作为积分 ∫ a b f ( x ) d x 的一个估计值 \int_a^bf(x)dx的一个估计值 abf(x)dx的一个估计值:
I = ∫ a b [ f ( a ) + f ( b ) − f ( a ) b − a ( x − a ) ] d x I= \int_a^b[f(a)+\frac{f(b)-f(a)}{b-a}(x-a)]dx I=ab[f(a)+baf(b)f(a)(xa)]dx
积分的结果为:
I = ( b − a ) f ( a ) + f ( b ) 2 I=(b-a)\frac{f(a)+f(b)}{2} I=(ba)2f(a)+f(b)
如图下图所示,从几何上看,梯形法则相当于用连接 f ( a ) f(a) f(a) f ( b ) f(b) f(b)的直线与坐标轴所围梯形的面积来逼近积分。在这里插入图片描述

SSM梯形法的基本原理

在梯形法中,我们通过计算两个端点KaTeX parse error: Unexpected character: '?' at position 1: ?̲?_n t n + 1 t_{n+1} tn+1处的函数值的平均,来近似曲线下的面积。其公式如下:

T = t n + 1 − t n 2 ( f ( t n ) + f ( t n + 1 ) ) T=\frac{t_{n+1}-t_n}{2}(f(t_n)+f(t_{n+1})) T=2tn+1tn(f(tn)+f(tn+1))

其中, T T T是在时间段 [ t n , t n + 1 ] [t_n,t_{n+1}] [tn,tn+1] 下的梯形面积,表示函数 f f f在该区间内的积分值。

离散过程

从连续时间系统的状态空间模型(SSM)出发,我们可以推导出离散时间的状态更新公式。假设系统的状态更新公式为:

x ′ ( t ) = A x ( t ) + B u ( t ) x′(t) = Ax(t)+Bu(t) x(t)=Ax(t)+Bu(t)

其中 x ′ ( t ) x′(t) x(t)是状态的导数,表示状态在时间 t t t变化的速率, A A A B B B分别是控制状态和输入的矩阵, u ( t ) u(t) u(t)是输入。

将上面的微分方程应用到梯形法中,我们得到:

x ( t + Δ t ) − x ( t ) = Δ t 2 ( f ( t ) + f ( t + Δ t ) ) x(t +Δt) - x(t) = \frac{Δt}{2}(f(t)+f(t+Δt)) x(t+Δt)x(t)=2Δt(f(t)+f(t+Δt))

将连续时间中的 f ( t ) f(t) f(t)用状态空间方程表示,即 f ( t ) = A x ( t ) + B u ( t ) f(t)=Ax(t)+Bu(t) f(t)=Ax(t)+Bu(t),我们得到:

x n + 1 − x n = △ 2 ( A x n + B u n + A x n + 1 + B u n + 1 ) x_{n+1} -x_n= \frac{\triangle}{2}(Ax_n+Bu_n +Ax_{n+1}+Bu_{n+1}) xn+1xn=2(Axn+Bun+Axn+1+Bun+1)
其中, Δ t = t n + 1 − t n Δt=t_{n+1}-t_n Δt=tn+1tn是时间步长。

重新整理并推导离散更新公式

通过代数变换,我们得到以下形式:
x n + 1 = x n + Δ 2 ( A x n + B u n + A x n + 1 + B u n + 1 ) x_{n+1} = x_n +\frac{Δ}{2}(Ax_n+Bu_n +Ax_{n+1}+Bu_{n+1}) xn+1=xn+2Δ(Axn+Bun+Axn+1+Bun+1)
x n + 1 x_{n+1} xn+1提取到方程的一侧,得到:
x n + 1 − Δ 2 A x n + 1 = x n + Δ 2 A x n + Δ 2 B ( u n + u n + 1 ) x_{n+1}-\frac{Δ}{2}Ax_{n+1}=x_n+\frac{Δ}{2}Ax_n+\frac{Δ}{2}B(u_n+u_{n+1}) xn+12ΔAxn+1=xn+2ΔAxn+2ΔB(un+un+1)

进一步整理,得到
( I − Δ 2 A ) x n + 1 = ( I + Δ 2 A ) x n + Δ B u n + 1 (I-\frac{Δ}{2}A)x_{n+1}=(I+\frac{Δ}{2}A)x_n+ΔBu_{n+1} (I2ΔA)xn+1=(I+2ΔA)xn+ΔBun+1
从而,最终得到离散化的更新公式:
x n + 1 = ( I − Δ 2 A ) − 1 ( I + Δ 2 A ) x n + ( I − Δ 2 A ) − 1 Δ B u n + 1 x_{n+1}=(I-\frac{Δ}{2}A)^{-1}(I+\frac{Δ}{2}A)x_n+(I-\frac{Δ}{2}A)^{-1}ΔBu_{n+1} xn+1=(I2ΔA)1(I+2ΔA)xn+(I2ΔA)1ΔBun+1
这个公式表明,在离散化后的系统中,新的状态 x n + 1 x_{n+1} xn+1是由当前状态 x n x_n xn和控制输入 u n , u n + 1 u_n,u_{n+1} un,un+1决定的。

离散化后的矩阵表示

为了简化表示,可以引入离散化后的矩阵 A ˉ , B ˉ , C ˉ \bar A,\bar B,\bar C Aˉ,Bˉ,Cˉ,它们分别对应离散时间系统中的状态转移矩阵、输入矩阵和输出矩阵。通过定义这些矩阵,我们可以得到以下形式的离散时间状态空间方程:

A ˉ = ( I − Δ 2 A ) − 1 ( I + Δ 2 A ) \bar A=(I-\frac{Δ}{2}A)^{-1}(I+\frac{Δ}{2}A) Aˉ=(I2ΔA)1(I+2ΔA)
B ˉ = ( I − Δ 2 A ) − 1 Δ B \bar B=(I-\frac{Δ}{2}A)^{-1}ΔB Bˉ=(I2ΔA)1ΔB
C ˉ = C \bar C=C Cˉ=C

离散化后的系统可以写为:

x k = A ˉ x k − 1 + B ˉ u k x_k=\bar Ax_{k-1}+\bar Bu_k xk=Aˉxk1+Bˉuk
y k = C ˉ x k y_k = \bar Cx_k yk=Cˉxk

其中, A ˉ , B ˉ , C ˉ \bar A,\bar B,\bar C Aˉ,Bˉ,Cˉ是经过离散化后的矩阵,表示离散时间状态空间模型的核心部分。

接下来是对SSM系统卷积视角的解释

SSM卷积基本原理

这个递推关系可以通过卷积来表示。为此,我们需要对系统的方程进行迭代。

首先,回顾系统的基本方程:
x k = A ˉ x k − 1 + B ˉ u k x_k=\bar Ax_{k-1}+\bar Bu_k xk=Aˉxk1+Bˉuk
y k = C ˉ x k y_k = \bar Cx_k yk=Cˉxk

分析第一个方程

对于第一个方程 x k = A ˉ x k − 1 + B ˉ u k x_k=\bar Ax_{k-1}+\bar Bu_k xk=Aˉxk1+Bˉuk,我们逐步迭代。

  • 第1步: x 0 = B ˉ u 0 x_0=\bar Bu_0 x0=Bˉu0
  • 第2步: x 1 = A ˉ x 0 + B ˉ u 1 = A ˉ B ˉ u 0 + B ˉ u 1 x_1=\bar Ax_0+\bar Bu_1=\bar A\bar Bu_0+\bar Bu_1 x1=Aˉx0+Bˉu1=AˉBˉu0+Bˉu1
  • 第3步: x 2 = A ˉ x 1 + B ˉ u 2 = A ˉ ( A ˉ B ˉ u 0 + B ˉ u 1 ) + B ˉ u 2 = A ˉ 2 B ˉ u 0 + A ˉ B ˉ u 1 + B ˉ u 2 x_2=\bar Ax_1+\bar Bu_2=\bar A(\bar A\bar Bu_0+\bar Bu_1)+\bar Bu_2=\bar A^2\bar Bu_0+\bar A\bar Bu_1+\bar Bu_2 x2=Aˉx1+Bˉu2=Aˉ(AˉBˉu0+Bˉu1)+Bˉu2=Aˉ2Bˉu0+AˉBˉu1+Bˉu2

通过类似的方式,可以得到一般的表达方式:
x k = ∑ i = 0 k A ˉ k − i B ˉ u i x_k=\sum_{i=0}^k\bar A^{k-i}\bar Bu_i xk=i=0kAˉkiBˉui

这里, x k x_k xk可以看作是一个函数 f f f,该函数由 ( u 0 , u 1 , … , u k ) (u_0,u_1,…,u_k) (u0,u1,,uk) 参数化。

分析第二个方程

接下来,考虑系统中的第二个方程 y k = C ˉ x k y_k=\bar Cx_k yk=Cˉxk,可以将之前计算的 x k x_k xk代入其中,得到输出:

  • 第1步: y 0 = C ˉ x 0 = C ˉ B ˉ u 0 y_0=\bar Cx_0=\bar C\bar Bu_0 y0=Cˉx0=CˉBˉu0
  • 第2步: y 1 = C ˉ x 1 = C ˉ ( A ˉ B ˉ u 0 + B ˉ u 1 ) = C ˉ A ˉ B ˉ u 0 + C ˉ B ˉ u 1 y_1=\bar Cx_1=\bar C(\bar A\bar Bu_0+\bar Bu_1)=\bar C\bar A\bar Bu_0+\bar C\bar Bu_1 y1=Cˉx1=Cˉ(AˉBˉu0+Bˉu1)=CˉAˉBˉu0+CˉBˉu1
  • 第3步: y 2 = C ˉ x 2 = C ˉ ( A ˉ 2 B ˉ u 0 + A ˉ B ˉ u 1 + B ˉ u 2 ) = C ˉ A ˉ 2 B ˉ u 0 + C ˉ A ˉ B ˉ u 1 + C ˉ B ˉ u 2 y_2=\bar Cx_2=\bar C(\bar A^2\bar Bu_0+\bar A\bar Bu_1+\bar Bu_2)=\bar C\bar A^2\bar Bu_0+\bar C\bar A\bar Bu_1+\bar C\bar Bu_2 y2=Cˉx2=Cˉ(Aˉ2Bˉu0+AˉBˉu1+Bˉu2)=CˉAˉ2Bˉu0+CˉAˉBˉu1+CˉBˉu2

从这些步骤可以看出,输出 y k y_k yk也是一个卷积的结果,可以表示为:
y k = ∑ i = 0 k C ˉ A ˉ k − i B ˉ u i y_k=\sum_{i=0}^k\bar C\bar A^{k-i}\bar Bu_i yk=i=0kCˉAˉkiBˉui

我们可以观察到,卷积核 K ˉ k \bar K_k Kˉk是由 C ˉ B ˉ , C ˉ A ˉ B ˉ , . . . , C ˉ A ˉ k B ˉ \bar C\bar B,\bar C\bar A\bar B,...,\bar C\bar A^k\bar B CˉBˉ,CˉAˉBˉ,...,CˉAˉkBˉ组成的因此我们有: K ∗ u K∗u Ku

现在我们了解了 SSM,我们可以看到如何使用它们来创建 Mamba 架构

Mamba:一种深度学习架构,专注于序列建模

Mamba 是一种新型的深度学习架构,专门用于序列建模。由卡内基梅隆大学和普林斯顿大学的研究人员开发,旨在解决变压器模型在处理长序列时的一些局限性。Mamba 基于结构化状态空间序列模型(S4),其设计理念通过结合连续时间、递归和卷积模型来有效处理长距离依赖问题。

论文地址:https://arxiv.org/html/2312.00752?_immersive_translate_auto_translate=1
github开源地址: https://github.com/state-spaces/mamba

架构概述

Mamba 架构的核心是S4模型,通过结合连续时间、递归和卷积模型,Mamba 能够高效地处理长序列数据,适应不规则采样的数据,处理无界上下文,同时在训练和推理时保持计算效率。Mamba 对 S4 进行了显著的增强,尤其在时间变化操作方面的处理,采用了一种独特的选择机制,根据输入数据动态调整 结构化状态空间模型(SSM) 的参数。

Mamba 在处理长序列数据时,能够根据输入数据的特征选择性地聚焦相关信息,过滤掉不相关的数据,从而提高处理效率。此外,Mamba 采用了硬件感知算法,利用 GPU 进行优化,通过内核融合、并行扫描和重计算来提高性能,并避免在内存密集型层中展开状态,从而提升了模型的训练和推理效率。

关键组件

  1. 选择性状态空间(SSM):Mamba 的核心组件,通过递归模型根据当前输入选择性地处理信息,从而专注于相关数据并丢弃不重要的信息。
  2. 架构:Mamba 采用统一的 SSM 块来代替传统变压器中的复杂注意力机制和多层感知器(MLP)块,从而降低了计算复杂度,提升了推理速度
  3. 硬件感知并行性:Mamba 采用并行算法与递归模式,专为硬件效率优化设计,这可能进一步提升模型性能。

下面将对这些组件详细介绍

选择性状态空间模型

选择性压缩与上下文信息处理

序列建模的核心挑战之一是如何在保证模型性能的同时,有效压缩和处理上下文信息。在许多现代序列模型中,尤其是 Transformer,虽然它们在处理大规模上下文时表现出色,但通常不会对上下文信息进行压缩。这意味着在推理阶段,模型需要显式地存储和操作大量的信息(例如,通过键值缓存(KV缓存))。这种方法在推理时效率较低,尤其是在输入序列较长时。相比之下,递归神经网络(RNN)等模型通过有限的状态压缩上下文信息,从而在推理阶段能够更高效地进行计算,但这也限制了它们对长程依赖的建模能力。

为了更好地理解这个问题,我们可以通过两个经典任务来解释选择性压缩的概念:

  • 选择性复制任务:在这个任务中,模型的目标是复制输入序列中某些特定的标记。然而,任务要求模型不仅要简单地复制所有的标记,还要根据上下文信息,选择性地记住和复制相关的标记。这就要求模型具备强大的内容感知能力,能够根据输入序列中的上下文动态地选择哪些信息需要保留,哪些信息可以丢弃。
  • 归纳头任务:在这个任务中,模型需要根据上下文信息检索正确的答案。这个任务实际上模拟了许多大规模语言模型(LLM)的核心工作机制,即基于给定的上下文来选择性地从大量候选信息中提取出最相关的部分。归纳头任务要求模型不仅记住输入序列,还能够根据上下文条件做出精确的选择和过滤,提供准确的输出。

这两个任务揭示了传统的线性时不变(LTI)模型的局限性。LTI模型的主要问题在于它们无法根据上下文信息进行选择和过滤。它们会在固定的时间范围内处理所有输入信息,而没有能力在长序列中选择性地关注最相关的部分。因此,LTI模型在处理复杂的任务时通常表现不佳,尤其是在需要高效上下文信息压缩和选择性过滤的任务中。

为了更好的理解这个选择性压缩的重要性,下面结合这个图我们来看

可以从 Transformer 模型的处理机制来简化理解这些两种复制任务。

Transformer 的结构擅长处理序列数据,因为它可以通过注意力机制(attention)动态地关注序列中重要的元素,从而完成任务。Transformer 能够应对"标准"和"选择性"这两种类型的复制任务。
在这里插入图片描述

1. 标准复制任务(左图)(固定间距)

在标准复制任务中,输入和输出的顺序是固定的、一一对应的。简单来说,就是“给什么输出什么”,比如输入是 A B C D E,那输出就是 A B C D E。这个过程在 Transformer 里可以实现为:

输入:A B C D E 输出:A B C D E

  • 注意力机制会默认保持顺序:在标准复制任务中,Transformer 的注意力机制可以均匀地关注每个输入单元(如 A、B、C 等),因为输入和输出之间有固定的间距,所以模型只需按顺序传递信息即可。这类似于“流水线”,每个输入单元的输出位置不变,所以 Transformer 只需要“按部就班”地把每个单元逐步复制到输出中即可。
  • 无需选择性关注:由于输入和输出完全匹配,Transformer 不用“挑选”特定单元,也不需要判断哪些元素重要、哪些可以忽略。这样,模型只要逐个顺序复制就行了。这种任务可以使用简单的结构来完成,不需要复杂的决策,注意力机制只用保持“恒定关注”每个输入元素。
2. 选择性复制任务(右图)(随机间距)

在选择性复制任务中,输入和输出并非固定对应,也就是说,模型需要“挑选”一些输入并丢弃其他的。例如,输入序列是 A B C D E F G H,但输出只选择某些元素,例如 B E H。这种任务要求模型根据上下文来决定哪些输入单元需要保留。

输入:A B C D E F G H 输出:B E H

  • 注意力机制会动态选择关注点:对于选择性复制任务,Transformer 的注意力机制就更有用了。它可以根据输入内容的不同,动态调整注意力,决定哪些元素是关键的、哪些可以忽略。在上例中,模型可能对 B、E 和 H 给更高的注意力权重,而忽略其他元素。
  • 不固定间距:因为间距是随机的,输出并不是和输入一一对应的。所以 Transformer 不会去“按顺序”复制,而是通过计算每个输入的注意力得分来选择要输出的元素。
  • 依赖上下文信息:模型通过注意力机制判断哪些元素是重要的,这时“上下文”信息(例如周围其他元素的内容)就很重要。Transformer 能够通过多层注意力机制结合上下文信息来做出“选择”,从而只复制对任务有用的元素。

那么SSM是如何改进,实现选择性的呢

选择改进 SSM

核心思路:引入“选择机制”,让模型中影响序列交互的参数(例如:RNN 中的状态更新参数,或者 CNN 的卷积核)不再是固定的,而是根据输入序列的内容动态生成。这意味着当不同的输入进入模型时,模型会自动调整参数来适应这些输入。通过这样动态的参数调整,模型能够在处理序列数据时做到“时间可变性”,即模型的行为可以随时间步长而变化,而不再局限于一个固定的行为模式。

换句话说,传统的模型在序列处理时会使用不变的参数,这对静态模式有效,但对于更复杂的、变化的序列数据则不够灵活。而“选择机制”可以使模型更“智能”,能够根据输入内容的不同,自适应调整交互参数,从而在处理动态变化的序列时更精确和灵活

算法1 :Structured SSM(S4)

此算法展示了一个标准的 SSM 模型,即时间不变的状态空间模型。在这种结构中,核心参数 Δ \Delta Δ A A A B B B C C C均作为固定参数,无论输入变化如何都保持不变,适用于具有稳定交互的序列任务。

  1. 输入输出:输入张量 x x x和输出张量 y y y的形状为 ( B , L , D ) (B,L,D) (B,L,D),其中 B B B是批次大小, L L L是序列长度, D D D是特征维度
  2. 参数矩阵
    • A A A:表示结构化的 N × N N \times N N×N矩阵。
    • B B B:与 A A A相同维度,用于序列更新。
    • C C C:负责输出映射。
    • Δ Δ Δ:时间步长,由参数化的 τ Δ \tau_Δ τΔ生成。
  3. 离散化:将连续参数 Δ , A , B Δ,A,B Δ,A,B离散化成离散形式 A ˉ , B ˉ \bar A,\bar B Aˉ,Bˉ
  4. SSM计算输出:执行 SSM 操作,生成序列的时间不变特征输出 y y y
Algorithm 1 SSM (S4)
1: x : (B, L, D)
2: y : (B, L, D)
3: A : (D, N) ← Parameter          ▷ Represents structured N×N matrix
4: B : (D, N) ← Parameter
5: C : (D, N) ← Parameter
6: Δ : (D) ← τΔ⁢(Parameter)
7: A¯, B¯ : (D, N) ← discretize(Δ, A, B)
8: y ← SSM⁢(A¯, B¯, C)⁢(x)         ▷ Time-invariant: recurrence or convolution
9: return y
改进算法2 :选择性SSM(S6)

为了适应不同输入序列,Algorithm 2 引入选择机制,将部分关键参数(如 B , C , Δ B,C,Δ B,C,Δ)动态地依赖于输入。这种设计不仅增加了灵活性,还使模型可以对不同输入条件调整状态空间参数。与标准 SSM 不同,此算法允许模型在序列的每个时间步上进行调整,实现时间可变性。

  1. 输入输出:输入张量 x x x和输出张量 y y y与算法 1 相同。
  2. 输入依赖参数
    • B B B:使用选择函数 KaTeX parse error: Unexpected character: '?' at position 1: ?̲?_𝐵(𝑥)=Linear…,使 B B B成为( B × L × N B \times L \times N B×L×N) 的输入依赖参数。
    • C C C:使用选择函数KaTeX parse error: Unexpected character: '?' at position 1: ?̲?_𝐶(𝑥) =Linea…生成输入依赖参数KaTeX parse error: Unexpected character: '?' at position 1: ?̲?
    • Δ Δ Δ:通过组合函数KaTeX parse error: Unexpected character: '?' at position 20: …_Δ(Parameter + ?̲?_Δ(𝑥))获得随输入变化的 Δ Δ Δ
  3. 离散化:同样对 Δ , A , B Δ,A,B Δ,A,B进行离散化处理,但在此结构下产生的 A ˉ , B ˉ \bar A,\bar B Aˉ,Bˉ具备时间变化特性。
  4. SSM计算时变输出:使用 SSM 进行时变的输出计算,得到结果 y y y
Algorithm 2 SSM + Selection (S6)
1: x : (B, L, D)
2: y : (B, L, D)
3: A : (D, N) ← Parameter          ▷ Represents structured N×N matrix
4: B : (B, L, N) ← sB(x)
5: C : (B, L, N) ← sC(x)
6: Δ : (B, L, D) ← τΔ⁢(Parameter + sΔ(x))
7: A¯, B¯ : (B, L, D, N) ← discretize(Δ, A, B)
8: y ← SSM⁢(A¯, B¯, C)⁢(x)         ▷ Time-varying: recurrence (scan) only
9: return y

比较和改进

  • 时间不变性 vs. 时间可变性:算法 1 中,模型在序列处理中的参数固定,适合规律性较强的任务。算法 2 通过将KaTeX parse error: Unexpected character: '?' at position 1: ?̲?、C 和 Δ 与输入关联,提供了对动态序列的适应能力。
  • 选择机制:算法 2 的选择机制让模型能够根据不同输入自动调整,从而改进对变化序列的适应性。这种选择机制的加入,虽然增加了计算复杂度,但显著提升了对不同输入特征的捕捉能力。

Mamba架构

Mamba 架构在设计上结合了选择性状态空间模型(SSM)和多层感知机(MLP)模块,构建出一种高效处理长序列的模型,尤其适用于减少计算成本并提升性能稳定性。该架构通过门控机制和选择性扫描操作,优化了对长序列的处理,且在多种应用中展现了良好的效果,如下图所示

在这里插入图片描述

架构关键要素:

  1. H3 架构基础
    H3 架构主要由状态空间模型(SSM)和卷积层组成,具备高效处理长序列的特点。这里的 SSM 是一种递归模块,旨在维持序列的状态,通过输入动态更新模型状态。H3 的主要特点包括以下几点:
  • SSM 层(State-Space Model Layer):这个模块通过状态空间形式来描述模型行为,用于跟踪输入序列的状态并进行递归预测。每个时间步都会更新状态,使模型能够持续记忆前序信息。
  • 卷积层(Convolution Layer):卷积层作用于输入序列,将其转换成适合 SSM 处理的状态序列。通过卷积层的应用,H3 可以将序列的局部特征提取为 SSM 所需的输入,有助于提升序列间依赖信息的表达能力。
  1. 门控 MLP 模块
    门控 MLP 模块(Gated MLP)是 Mamba 的另一个核心组件,用于增强网络的非线性表达能力,主要由多层感知机(MLP)结构和门控机制组成。
  • 多层感知机(MLP):MLP 部分负责将输入映射到高维特征空间,通过多个全连接层(fully connected layers)加深网络的特征捕获能力。为了提升性能,该模块引入了 SwiGLU(Swish-Gated Linear Units)变体,增加了更复杂的激活模式。
  • 门控机制(Gating Mechanism):通过引入门控结构,模块可以选择性地通过或过滤掉输入信息。Swish 或 SiLU 激活函数通常用于门控层,能够动态调节信息流动,提高信息的非线性表达效果。这种选择性过滤操作有助于模型在处理长序列时聚焦于关键输入,避免噪声干扰。
  1. 曼巴块的设计
    Mamba 块是 H3 架构和门控 MLP 的结合体,通过将两者与额外的 SSM 层组合成一个模块,形成了 Mamba 独有的架构。以下是 Mamba 块的设计细节:
  • 层级结构:在 Mamba 模块中,SSM 层、卷积层、门控 MLP 层依次排列。每个模块间通过规范化(Normalization)和残差连接(Residual Connections)连接起来,提升了模型的稳定性和优化效果。
  • 去除乘法门和增加激活函数:相比 H3 模块,Mamba 模块中去除了第一个乘法门,用激活函数(如 SiLU)取而代之。这样能够在保留 H3 模块基本结构的前提下,提升 MLP 的处理灵活性。
  • 附加的 SSM 层:Mamba 模块增加了一层 SSM,在主分支中引入状态空间模型,使得模块具备更强的序列处理能力。这种设计在模块内部构建了多层次的状态跟踪能力,使得模型能有效管理和优化长序列中的重要上下文。
  1. 模块整合与选择性机制
    Mamba 架构通过堆叠多个 Mamba 块,形成具有选择性扫描操作的深度网络结构。该结构能够在长序列中选择性记忆或忽略特定输入,具体机制如下:
  • 选择性扫描操作(Selective Scan Operation):这是 Mamba 在长序列处理中的一项关键创新。通过选择性扫描操作,模型可以控制不同模块对输入序列的敏感度,从而聚焦于关键部分而忽略不重要的信息。这一机制类似于门控机制,但在更高层次上对序列信息进行选择性过滤。
  • 重置状态功能:选择性扫描还具备状态重置功能。在处理跨序列的数据时,模型可以随时重置当前状态,避免不同序列间的信息干扰。这一功能使 Mamba 更适合多序列输入,能有效避免传统长序列模型中常见的上下文信息污染。
  1. 设计上的优势
  • 可扩展性:Mamba 通过将 H3 和门控 MLP 模块有机组合,使其在大规模数据处理时具有更好的可扩展性。相比于 Transformer 中的多头注意力(MHA)结构,Mamba 的门控结构和选择性机制减少了对全局上下文的依赖,在长序列上具有更高的计算效率。
  • 高效参数化:通过设置固定的扩展因子 E E E,Mamba 在模型输入输出投影中集中了大部分参数,使得它在参数使用上更为高效。此外,SSM 层的参数较少,有助于减少模型的计算开销。

Mamba 变体

  1. 无标记语言模型:MambaByte

传统的变压器模型在处理语言时通常依赖于标记化(tokenization),将文本分解为子词单元。然而,这种方法会导致庞大的词汇表和词嵌入表,并且对少数词或新词的处理有局限性。

MambaByte 通过直接处理原始字节序列来避免标记化过程,具有以下优势:

  • 语言独立性:避免了基于语言规则的标记化过程,可以处理各种语言而不需要语言特定的适配。
  • 消除子词标记化的偏差:避免了常见子词过度表示和稀有子词或新词表示不足的问题,特别适用于语言具有丰富形态变化的情况。
  • 简化预处理:省去了复杂的标记化和词汇管理步骤,减少了预处理流程中的潜在错误。

2.Mamba 专家混合模型(MoE)

MoE Mamba 是将 专家混合(MoE) 技术与 Mamba 架构结合的创新模型,旨在提升 SSMs 在语言建模中的效率和可扩展性。该模型通过交替使用 Mamba 和 MoE 层,能够高效地整合整个序列的上下文,并为每个标记应用最相关的专家,从而在训练效率上相较于前一个版本减少了 2.2 倍的训练步骤,同时保持了竞争力的性能。

3.视觉 Mamba(Vim)

Vim 是 Mamba 的一个变体,专门处理视觉数据。它将 SSM 与视觉数据处理结合,采用双向 Mamba 块进行视觉序列编码,显著减少了与自注意力机制相关的计算需求。经过 ImageNet 分类、COCO 物体检测和 ADE20k 语义分割测试,Vim 在处理高分辨率图像时显示了更高的性能和更低的计算资源消耗,是未来视觉表示学习领域可扩展模型的有力候选。

4.Jamba

Jamba 是由 AI21 Labs 开发的混合变压器与 Mamba 架构的模型,具有 520 亿个参数,是迄今为止创建的最大的 Mamba 变体。Jamba 拥有 256k 标记的上下文窗口,展示了 Mamba 架构在大规模语言建模中的巨大潜力。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

T1.Faker

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值