力压Transformer,详解Mamba和状态空间模型

大家好,大型语言模型(LLMs)之所以能够在语言理解与生成上取得巨大成功,Transformer架构是其强大的支撑。从开源的Mistral,到OpenAI开发的闭源模型ChatGPT,都采用了这一架构。

然而技术的探索从未止步,为进一步提升LLMs的性能,学界正在研发能够超越Transformer的新架构。其中,Mamba模型以其创新的状态空间模型(State Space Model)成为研究的焦点。

本文将介绍Mamba模型及其在语言建模领域的应用,逐步解析状态空间模型的基本概念,并通过丰富的可视化内容,让大家直观地理解这一技术如何有望挑战现有的Transformer架构。 

1.Transformer架构的挑战

先对Transformer架构做一个快速回顾,并指出其存在的一个主要缺陷。

Transformer架构将文本输入视为由一系列token构成的序列:

图片

其核心优势在于,无论面对何种输入,都能追溯到序列中的早期token,以此来推导出其深层的语义表示。

图片

1.1 Transformer的核心组件

Transformer由架构两部分组成:编码器和解码器。编码器负责解析文本,而解码器则负责生成文本。这种结构的结合,使之能够胜任从文本翻译到内容创作的多种任务。

图片

我们可以仅利用解码器部分来创建生成式模型,这种基于Transformer的模型,即生成预训练Transformer(GPT),通过解码器来续写或补全输入的文本,展现出其在文本生成方面的强大能力。

图片

1.2 自注意力的高效训练

单个解码器块由两个主要部分组成,即“掩蔽自注意力机制(Masked Self-attention)”和“前馈神经网络(Feedforward Neural Network)”。

图片

自注意力是这些模型运行如此良好的主要原因,它提供了整个序列的未压缩视图,并加快了训练速度。具体来说,自注意力机制的工作原理是通过创建一个矩阵,该矩阵对序列中的每个token与之前所有token进行比较,并通过计算它们之间的相关性来确定权重。

图片

在训练过程中,自注意力矩阵是一次性整体构建的,这表示不需要依次等待每个token的注意力计算完成,而是可以同时进行整个序列的注意力计算。例如,在处理“我”和“名字”的关联之前,无需先完成“名字”和“是”的关联计算。

这种设计实现了训练过程的并行化,极大地提升了训练速度,使Transformer架构在处理大规模数据集时更加高效。

1.3 训练与推理的矛盾

然而Transformer架构也有其局限性,每当生成新的token时,必须对整个序列的注意力权重重新进行计算,哪怕此前已经生成了若干token。

图片

生成长度为L的序列需要大约L²次计算,随着序列的延长,计算成本会急剧上升。

图片

这种对序列全面重新计算的需求,是Transformer架构的一个主要瓶颈。

接下来,让我们看看传统的递归神经网络(RNN)是如何克服这一推理过程中的效率问题。

1.4 RNN的潜力

递归神经网络(RNN)是一种处理序列数据的网络结构。在序列的每个时间点,RNN接收两个输入:当前时间点t的输入数据和上一个时间点t-1的隐藏状态,以此来计算下一个隐藏状态并预测输出结果。

RNN具有循环机制,能够将历史信息传递至下一步,类似于将每一步的信息“串联”起来。这种机制可以通过可视化的方式“展开”,以便更清晰地理解其工作原理。

图片

在生成输出时,RNN仅依赖于前一步骤的隐藏状态和当前的输入数据,避免了像Transformer那样需要重新计算整个序列的历史隐藏状态。

正因如此,RNN在进行推理时速度较快,因为它的计算量与序列长度呈线性关系,理论上能够处理无限长的上下文。

举例来说,将RNN应用于之前的输入文本:

图片

每个隐藏状态都是对之前所有状态的压缩汇总。

但这里存在一个问题:随着时间的推移,比如在生成名字"Maarten"时,最后一个隐藏状态可能已经丢失了对"Hello"的记忆,因为RNN在每一步只考虑了前一个状态的信息。

此外,RNN的这种顺序依赖性也导致了另一个问题:它的训练过程无法并行化,必须按顺序逐步进行。

图片

与Transformer相比,RNN在推理速度上具有优势,但在训练并行化方面却存在不足。

图片

这就引出了一个问题:能否找到一种结合了Transformer训练并行化优势和RNN线性推理优势的架构,答案是肯定的,Mamba模型就是。在深入了解Mamba架构之前,先来了解状态空间模型的世界。

2.状态空间模型(SSM)

状态空间模型(SSM),像Transformer和RNN一样,处理信息序列,如文本和信号。在这一部分中,我们将了解SSM的基础知识以及其与文本数据的关系。

2.1 什么是状态空间

状态空间模型是一种通过数学方法全面描述系统状态的方式,包含了描述系统所必需的全部最小变量。

简单来说,就像在迷宫中寻找路径,状态空间就是那张展示所有可能位置(即状态)的地图。在这张地图上,每个点都代表一个独特的位置,并且携带了如距离出口远近等具体信息。

进一步简化这个概念,可以将“状态空间表示”理解为这张地图的提炼,它不仅告诉我们当前所在的位置(即当前状态),还展示了可能的目的地(未来状态),以及如何通过特定的行动(比如右转或左转)达到下一个状态。

图片

虽然状态空间模型依赖方程和矩阵来捕捉系统的行为,但其核心目标是追踪系统的位置、可能的移动方向及其变化路径。

在这个模型中,用以描述状态的变量,如示例中的X和Y坐标或者到出口的距离,统称为“状态向量”。

图片

这听起来有点熟悉,因为在语言模型中,类似的嵌入或向量经常用来描述输入序列的“状态”。例如,你当前的位置状态就可以通过一个向量来表示:

图片

在神经网络的语境下,系统的“状态”通常指的是其隐藏状态,这在生成新token的过程中,尤其是在大型语言模型的背景下,扮演着至关重要的角色。

2.2 状态空间模型(SSM)

状态空间模型(SSM)是一种描述系统状态并预测其未来状态的模型,它能够基于当前状态和特定输入来推断接下来可能的状态变化。

在传统方法中,状态空间模型(SSM)在特定时间点 t 的操作流程如下:

  • 映射输入:SSM首先将输入序列 x(t) 映射到潜在的状态表示 h(t) 。以迷宫为例,当你向左下方移动时,这一动作会被转化为与出口的距离以及具体的坐标位置。

  • 预测输出:接着,SSM利用这些状态信息来预测输出序列 y(t) 。比如,为了更快接近出口,SSM可能会建议你再次向左移动。

不同于传统的离散序列输入方法,如简单的一步移动,我们采用的是连续序列输入,并预测其输出序列。

图片

在状态空间模型(SSM)中,认为动态系统,例如在三维空间中移动的物体,可以通过其在特定时间点的状态,通过两个核心方程来预测。这两个方程是状态空间模型的基石。

图片

我们的目标是通过解决这些方程,发现统计原理,利用观测到的数据——包括输入序列和之前的状态——来预测系统的未来状态。我们追求的是找到一种状态表示 ℎ(𝑡) 能够将输入序列有效地转换为输出序列。

图片

为了便于理解,本文会对这两个方程进行颜色编码,以便能够迅速识别和引用。

状态方程通过矩阵A来展示状态的变化过程,同时矩阵B则揭示了输入如何对状态产生影响。

图片

简单来说,ℎ(𝑡) 代表在特定时间t的状态,而 x(t) 则表示相应的输入值。

进一步地,输出方程通过矩阵C将状态转化为输出,矩阵D则说明了输入如何直接影响输出。

图片

需要指出的是,这些矩阵A、B、C和D不仅是方程中的重要参数,还是可以通过学习过程进行调整的。

可视化这两个方程,得到以下架构:

图片

接下来逐步深入,用通俗易懂的语言来解析这些矩阵是如何在学习和理解过程中发挥作用的。

首先,设想接收到一系列输入信号 x(t) ,这些信号首先通过矩阵B进行处理,矩阵B的作用是展示输入如何影响整个系统。

图片

经过矩阵B的处理后,得到更新后的状态,这可以类比于神经网络中的隐藏层状态,它蕴含了环境的关键信息。接下来,将这个状态与矩阵A相乘。矩阵A揭示了系统内部状态之间的相互联系,代表了系统的内在动态。

图片

值得注意的是,矩阵A在状态表示生成之前就已经被应用,并且在状态更新之后还会继续发挥作用。

进一步,矩阵C负责将状态转换为最终的输出。

图片

而矩阵D则提供了从输入到输出的直接映射,这在某些情况下也被称作"跳过连接"(skip connection)。

图片

然而,尽管矩阵D的功能类似于跳过连接,但在状态空间模型(SSM)中,通常不将其视为包含跳过连接的模型。这种设计使得SSM在处理信息时更加直接和高效。

图片

重新聚焦于状态空间模型(SSM)的核心要素:矩阵A、B和C。

图片

通过重新审视并优化原始方程,可以用更生动的方式展示每个矩阵的特定功能。

图片

就像之前所说的,这些方程的共同目的是利用观察到的数据来预测系统的状态。考虑到输入通常是连续的,SSM主要采用的是连续时间的表示方法。

2.3 从连续到离散信号

处理连续信号以找到状态表示 h(t) 在技术上是一项挑战。特别是当我们面对的是离散输入,比如文本序列,需要将模型转换为离散形式。

为此,我们采用零阶保持技术。这项技术的操作原理相当直观:每当接收到一个离散信号,我们就保持其值,直到下一个离散信号的到来。这样,就生成了一个连续信号,能够被状态空间模型(SSM)所使用:

图片

这个连续信号的保持时间由一个新的可学习参数 决定,它代表了输入信号的分辨率。

有了这个连续信号,就能够产生连续的输出,并在每个输入时间步长处进行采样。

图片

这些采样的值就是离散化输出,在数学上零阶保持的应用可以这样表达:

图片

这种转换允许从连续的SSM过渡到离散的SSM,不是从函数到函数的映射,而是从序列到序列的转换,即 xₖ → yₖ

图片

在这个过程中,矩阵A和B现在代表了模型的离散化参数。使用下标k来代替t,以更清晰地区分离散化和连续化的SSM。

需要注意的是,在训练过程中仍然保留矩阵A的连续形式,而不是其离散化版本。这意味着在训练时,连续表示会被转换成离散形式。

现在已经建立了离散表示的公式,接下来探索如何实际计算这个模型。

2.4 递归表示

我们的离散化状态空间模型(SSM)允许在特定的时间步长上处理问题,而不是在连续信号中。正如之前提到的循环神经网络(RNN)一样,递归方法在这里同样适用。

将视角转向离散时间步长而非连续信号时,问题可以重新定义如下:

图片

在每个时间步长,我们首先计算当前输入 (Bxₖ) 对先前状态 (Ahₖ₋₁) 的影响,然后基于这个更新后的状态来预测输出 (Chₖ)

图片

这种处理方式应该有些熟悉,因为这与之前处理RNN的方法相似。

图片

可以进一步展开这个过程:

图片

注意如何运用这种离散化模型,借鉴了循环神经网络(RNN)的核心处理方式。

这种方法让我们得以享受RNN的快速推理能力,但同时也不得不面对其训练周期较长的不足之处。

2.5 卷积表示

状态空间模型(SSM)的另一种表达方式是通过卷积来实现。回想一下,在传统的图像识别任务中,使用滤波器(卷积核)来提取图像的特征:

图片

由于处理的是文本而非图像,需要采用一维卷积的视角:

图片

SSM中的卷积核是从模型公式中派生出来的,用于捕捉序列数据中的局部依赖关系:

图片

接下来看看这个卷积核在实际应用中是如何工作的。类似于卷积操作,可以利用SSM的卷积核来逐个处理文本中的标记序列,并计算输出:

图片

这个过程也展示了填充(padding)对输出的影响。这里调整了填充的顺序以便于可视化,通常会在句子的末尾添加填充。

随着卷积核的移动,它继续对序列的下一部分进行计算:

图片

在最后一步,可以观察到卷积核在整个序列上的作用效果:

图片

将SSM以卷积的形式表示的一个主要优势在于,它能够像卷积神经网络(CNN)一样进行高效的并行训练。然而由于卷积核的大小是固定的,SSM在推理速度上可能不如RNN那样快速和灵活。

2.6 三种表示

这三种状态空间模型(SSM)的表示形式——连续的、递归的和卷积的,各有其特点和适用场景:

图片

首先,递归SSM具有高效的推理能力,而卷积SSM则以其可并行化的训练过程受到青睐。

利用这些不同的表示,可以灵活地根据任务需求选择最合适的模型。例如,在训练阶段,可以选择使用卷积表示以实现并行化;而在推理阶段,则可以切换到递归表示以获得更快的响应速度:

图片

这种灵活运用不同表示的模型被称为线性状态空间层(LSSL)。

这些模型有一个共同的重要特性,即线性时不变性(LTI)。LTI意味着SSM的参数A、B和C在所有时间步长上都是恒定不变的。

换句话说,无论输入序列如何变化,这些参数的值始终如一,提供了一种静态且非内容感知的模型表示。

在深入探讨Mamba如何解决这一问题之前,我们先来关注这个难题的核心部分——矩阵A。

2.7 矩阵A的重要性

可以说,SSM公式中最重要的就是矩阵A。正如之前在递归表示中所看到的,矩阵A负责捕捉先前状态的信息,以便构建新的状态。

图片

本质上,矩阵A产生隐藏状态:

图片

它的作用可以简单理解为:在只记忆几个最近标记的同时,捕捉到目前为止所有标记之间的差异。这一点在递归表示中尤为明显,因为它只考虑前一个状态。

那么,如何创建一个既能够保留大量记忆(上下文信息),又能够捕捉到关键信息的矩阵A呢?

这里,引入了一种称为“饥饿的河马”(HiPPO)的技术,即高阶多项式投影算子。HiPPO的目标是将所有迄今为止接收到的输入信号压缩成一个系数向量。

图片

HiPPO通过矩阵A构建一个状态表示,这个表示不仅能够很好地捕捉到最近的标记,还能使旧标记的影响逐渐减弱。其具体的数学表达式如下:

图片

假设有一个方阵A,它在模型中扮演着至关重要的角色。

图片

采用HiPPO(高阶多项式投影算子)来构建矩阵A,这种方法比随机初始化矩阵A要有效得多。HiPPO能够更准确地重建新信号(即最近的标记),同时保留对旧信号(初始标记)的记忆。

HiPPO的核心思想在于,它能够生成一个能够记住其历史状态的隐藏状态。

在数学上,HiPPO通过追踪Legendre多项式的系数来实现这一点,这使得它能够近似地捕捉到所有先前的历史信息。

将HiPPO技术应用于我们之前讨论过的递归和卷积表示,可以帮助我们更好地处理长距离依赖问题。这种应用的结果是结构化状态空间序列(S4),这是一种专为处理长序列设计的SSM。

S4模型由以下三个主要部分组成:

  • 状态空间模型:提供基本的动态系统描述。

  • HiPPO:专门用于处理长距离依赖。

  • 离散化:将递归和卷积表示结合起来,形成适合离散时间步的模型。

图片

根据选择的表示方式(递归或卷积),S4模型具有多种优势。它还可以通过构建HiPPO矩阵来高效地存储记忆,从而处理长文本序列。

3.Mamba:一种选择性状态空间模型

现在已经全面掌握了理解Mamba独特之处所需的基础知识。状态空间模型在模拟文本序列方面具有潜力,但也存在一些固有的局限性,这是我们希望克服的。

接下来,将介绍Mamba的两大创新:

  • 选择性扫描算法:这种算法赋予模型智能筛选信息的能力,能够识别并过滤掉无关紧要的内容。

  • 硬件感知算法:该算法通过并行扫描、内核融合和重新计算等技术,优化了中间结果的存储效率。

这两项技术的结合,创造出了选择性状态空间模型,简称S6模型。S6模型的设计理念与自注意力机制相似,可以构建出功能强大的Mamba模块。

3.1 解决的问题

虽然状态空间模型,包括高级的S4(结构化状态空间模型),在模拟文本序列方面具有其优势,但它们在执行某些关键任务时表现并不尽如人意。特别是在需要模型能够灵活关注或忽略特定输入的情况下,这些模型的能力显得捉襟见肘。

可以通过两个合成任务来说明这一点:选择性复制和归纳头。这两个任务能够直观地反映出模型在处理特定输入时的不足。

在选择性复制任务中,状态空间模型(SSM)面临的挑战是识别并按特定顺序复制输入序列的片段。

图片

然而,无论是递归型还是卷积型的SSM,在这项任务上的表现都不尽人意。原因在于SSM的线性时不变特性:对于生成的每个标记,矩阵A、B和C都是固定不变的。

这种固定性导致了SSM在内容感知推理方面的不足。由于A、B和C矩阵的一成不变,SSM对所有标记的处理都是平等的,无法根据标记的实际内容进行差异化处理。这对于我们希望模型能够对输入(尤其是提示信息)进行深入推理和响应的目标来说,无疑是一个重大障碍。

状态空间模型(SSM)在执行归纳头任务时同样面临挑战。这项任务要求模型能够识别并再现输入中的特定模式。

图片

例如,在一次性提示的场景中,希望模型在遇到每个“Q:”之后能够自动提供相应的“A:”响应。然而SSM由于其固有的时间不变性,难以从历史信息中选择性地回忆并应用先前的标记。

这一点通过矩阵B的表现尤为明显。无论输入x如何变化,矩阵B始终保持一致,与输入内容无关:

图片

同样,矩阵A和C也不受输入影响,始终固定不变。这种静态特性限制了SSM在处理需要动态调整注意力的任务时的表现。

图片

与此相对,Transformer模型能够根据输入序列动态调整其注意力焦点,灵活地“查看”或“关注”序列的不同部分。这种能力使得Transformer在执行归纳头任务时相对容易。

SSM在这些任务上的不足,揭示了其时间不变性所带来的问题。A、B和C矩阵的静态特性,导致了模型在内容感知方面的局限,无法像Transformer那样灵活地处理信息。

3.2 选择性保留信息

状态空间模型(SSM)的递归表示虽然以小状态压缩了整个历史,从而提高了效率,但与能够全面捕捉历史信息的Transformer模型相比,其能力还是有所不及。Transformer模型通过注意力矩阵避免了信息压缩,从而在处理信息时更为强大。

Mamba模型的设计理念是在保持状态小巧的同时,实现与Transformer相媲美的强大功能:

图片

它通过选择性地压缩输入数据到状态中,实现了这一点。在处理输入句子时,Mamba能够识别并忽略那些意义不大的信息,如停用词,从而只保留关键内容。

为了实现这种选择性压缩,Mamba模型的参数需要根据输入动态调整。

图片

在传统的结构化状态空间模型(S4)中,矩阵A、B和C的维度是静态的,与输入无关。

图片

而Mamba则不同,它将序列长度和批量大小纳入考量,使矩阵B和C,甚至步长∆,都与输入相关联:

图片

这种设计意味着,对于每个输入标记,Mamba都会使用不同的B和C矩阵,从而解决了内容感知的问题。

注意:尽管矩阵A保持不变,以维持状态的稳定性,但通过B和C对状态的影响却是动态的,能够根据输入灵活调整。

Mamba模型通过调整步长∆的大小,可以灵活选择在隐藏状态中保留或忽略的信息。

较小的步长∆有助于忽略某些具体单词,更多地利用先前的上下文信息;而较大的步长∆则使模型更加关注当前输入的单词,而不是依赖于上下文:

图片

3.3 扫描操作优化

由于Mamba模型中的矩阵现在是动态变化的,它们不再适用传统的卷积计算方法,因为卷积计算依赖于固定的核心。这意味着我们只能依赖递归表示,而递归意味着失去了卷积所提供的并行处理优势。

然而,Mamba并未止步于此。为了实现并行化,Mamba探索了一种新的计算输出的方法:

图片

在递归表示中,每个状态是由前一个状态(乘以矩阵A)与当前输入(乘以矩阵B)相加得到的总和。这个过程称为扫描操作,通常可以通过简单的for循环来实现。

虽然从表面上看,由于每个状态的计算依赖于前一个状态,实现并行化似乎是不可能的任务。但Mamba通过引入并行扫描算法,巧妙地解决了这一难题。

Mamba的并行扫描算法基于一个关键的假设:操作的执行顺序并不重要,这是由于数学上的结合律所保证的。基于这一原理,我们可以将序列分解为部分进行计算,然后再逐步将这些部分组合起来。

图片

正是这种动态的矩阵B和C,结合了并行扫描算法,共同构成了Mamba的选择性扫描算法。这一算法不仅捕捉了递归表示的动态特性,还实现了快速的信息处理。

3.4 硬件感知算法

现代GPU虽然拥有高效的SRAM和容量更大的DRAM,但它们之间的数据传输速度有限,成为制约性能的瓶颈。频繁在这两种内存之间复制数据会严重影响处理速度。

图片

为了解决这一问题,Mamba采用了与Flash Attention相似的策略,通过减少从DRAM到SRAM以及反向的数据传输次数来优化性能。Mamba通过内核融合技术,将多个计算步骤合并到一个内核中,避免了中间结果的频繁写入和读取,从而持续进行计算直至完成。

图片

通过可视化Mamba的基础架构,可以清晰地看到DRAM和SRAM的分配情况:

图片

在Mamba的设计中,以下步骤被融合为一个高效的内核处理流程:

  • 步长∆的离散化步骤

  • 选择性扫描算法的应用

  • 与矩阵C的乘法操作

Mamba的硬件感知算法还包括一个关键的优化措施——重新计算。

在传统的计算过程中,中间状态需要被保存以供后续使用。然而,Mamba选择不在内存中保存这些状态,而是在反向传播阶段根据需要重新计算它们。

这种方法虽然初看似乎效率不高,但实际上,与从速度较慢的DRAM中读取大量中间状态相比,重新计算的成本要低得多。

这种架构通常被称为选择性状态空间模型或S6模型,因为它本质上是使用选择性扫描算法计算的S4模型。

Mamba的这种架构,被形象地称为选择性状态空间模型,简称S6模型。它本质上是对S4模型的一种改进,通过选择性扫描算法进行计算,不仅提升了性能,还保持了模型的灵活性和可扩展性。

这一创新架构的详细描述可以在Gu, Albert和Tri Dao的研究文章中找到,文章标题为“Mamba: Linear-time sequence modeling with selective state spaces”,发表于arXiv预印本服务器。

图片

选择性状态空间模型(摘自:Gu、Albert 和 Tri Dao。"Mamba:ArXiv preprint arXiv:2312.00752 (2023).)

3.5 Mamba模型的模块化与性能优势

目前为止,我们所了解的选择性状态空间模型(SSM)可以作为一个独立的模块实现,类似于解码器中自注意力机制的实现方式。

图片

正如解码器可以堆叠多个层以深化处理一样,Mamba模型也可以通过堆叠多个Mamba块来增强其功能,每个块的输出直接作为下一个块的输入:

图片

这一过程始于输入嵌入的线性投影,以扩展数据维度,然后通过卷积层进行预处理,避免对独立标记的单独计算。

Mamba模型具有以下特点:

  • 递归状态空间模型(SSM)通过离散化技术构建,以实现高效的信息处理。

  • 矩阵A采用HiPPO初始化方法,以捕捉和维持长距离依赖关系。

  • 选择性扫描算法,智能地压缩信息,仅保留关键数据。

  • 硬件感知算法,优化计算流程,加速模型运行。

在代码实现层面,我们可以进一步扩展这一架构,研究端到端示例的构建细节:

图片

例如,引入归一化层和softmax函数来选择输出标记,这些都是模型性能优化的关键环节。

图片

综合这些元素,Mamba模型不仅实现了快速的推理和训练,还具备了处理无界上下文的能力。实际应用中,研究者发现Mamba模型在性能上能够匹敌甚至超越同等规模的Transformer模型。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

python慕遥

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值