结构化状态空间模型的直观解释(第二部分)

第 2 部分 - 面向图像、视频和时间序列的 Mamba 状态空间模型(欢迎来到雲闪世界。)

添加图片注释,不超过 140 字(可选)

年代状态空间模型几十年来为许多工程学科所熟知,现在在深度学习中首次亮相。在我们探索 Mamba 选择性状态空间模型及其最新研究成果的过程中,了解状态空间模型至关重要。而且,正如工程中经常出现的情况一样,正是细节让理论概念在实践中得以应用。除了状态空间模型之外,我们还必须讨论如何将它们应用于序列数据、如何处理长距离依赖关系以及如何通过利用某些矩阵结构来有效地训练它们。

结构化状态空间模型为 Mamba 构建了理论基础。然而,它们与系统理论和高级代数的联系可能是采用这一新框架的障碍之一。 因此,让我们分解一下,确保我们理解关键概念并将它们可视化,以阐明这一新旧理论。

即使您最终没有使用状态空间模型,了解一些技巧(例如为什么我们需要加速矩阵乘法以及如何利用某些矩阵结构来实现这一点)也肯定会提升您作为工程师或开发人员的技能。

在第 1 部分中,我们回顾了循环神经网络 (RNN) 和 Transformers 的优缺点,以说明为什么我们需要一种新的模型架构。 我们说 RNN 推理速度快,但训练速度慢,而 Transformer 训练速度快,但推理速度慢。 我们想要找到一个训练和推理速度都很快的模型,同时要与 Transformer 的性能相媲美。

2. 状态空间模型简介 Mamba 建立在通过学习各种矩阵将状态空间模型用于深度学习的理念之上。因此,在探索“结构化”部分的含义之前,让我们先简单了解一下状态空间模型。 状态空间模型可以定义为连续时间表示以处理连续信号,也可以离散化以处理离散数据序列。 我们最感兴趣的是离散状态空间模型,因为像循环神经网络 (RNN) 和 Transformers 一样,离散 SSM 处理数据序列,例如文本标记或模拟时间信号的样本。 2.1 状态空间模型的连续时间表示 连续时间状态空间模型描述了通过状态系统传播的输入信号与产生的输出信号之间的关系。

图 1:状态空间模型的高级功能。

这是从系统论中借用的一个思想。输出取决于输入和系统的当前状态,而当前状态取决于先前的状态和输入。 这种关系可以通过两个简单的方程有效地表述出来,其中A、B、C和D是矩阵(稍后我们会看到),x(t)是输入信号,y(t)是输出信号,h(t)和h'(t)分别是当前状态和更新状态。

方程 1:连续时间 SSM 的状态和输出方程。

矩阵D将输入x(t)进行变换,并将其映射到输出y(t),这通常包含在经典 SSM 的输出方程中。 然而,当将 SSM 应用于深度学习时,我们会删除此变换并将其建模为简单的跳跃连接,从而简化 SSM。 我们可以用如下框图来表示这些方程:

图 2:连续时间状态空间模型的框图

这是 SSM 的连续时间表示。但有两个困难:

  1. 找到模型状态h(t)的解析解具有挑战性,并且

  2. 因为我们使用计算机工作,所以我们通常处理离散信号,而不是连续信号。

2.2 状态空间模型的离散化 因此,我们实际上不是将函数x(t)映射到函数y(t),而是将序列x[k]映射到序列y[k]。这意味着我们需要一个可以处理离散信号的离散状态空间模型。

图 3:离散状态空间模型的高级功能。

为此,我们将矩阵A和B离散化。这得到了状态空间模型的离散表示,现在包括离散状态h[k]。

方程2:离散化SSM的状态和输出方程。

为了离散化我们的状态空间模型,我们使用了系统理论中经常使用的另一个技巧:我们使用零阶保持 (ZOH) 模块将离散信号x[k]转换为连续时间信号x(t),然后将其输入到连续 SSM 中。SSM 将生成连续输出y(t),然后我们对其进行采样以获得离散输出序列y[k]:

图 4:离散状态空间模型的框图。

要离散化状态空间模型,我们必须离散化矩阵A和B。最初,大多数基于深度学习的状态空间模型(如LSSLS4)都使用了双线性离散化,而后来的模型(如DSS和 Mamba)则使用了 ZOH 离散化。现在,我们关心的只是我们需要离散化,而且可能有多种方法可以做到这一点。

公式 3:零阶保持(ZOH)离散化与双线性离散化。

A和B是 SSM 的连续时间矩阵,I是单位矩阵,Δ 是表示输入分辨率的步长,在深度学习设置中是学习到的。 请注意,离散化允许我们使用 SSM 的时间连续公式并将其应用于计算机中使用的离散信号。这意味着我们变得具有分辨率不变性,因为我们学习了底层的时间信号。查看 S4ND 论文了解更多信息。 2.3 SSM 的循环和卷积表示 您是否注意到状态方程中的递归性,因为h[k]依赖于h[k-1]?事实上,状态空间模型在这方面非常像 RNN:

图 5:离散状态空间模型的递归表示。

矩阵A、B和C是时不变的。这意味着每个输入和每个状态都使用相同的参数进行变换,而不管时间步长如何。线性时不变系统或短 LTI 系统具有一个非常巧妙的属性:它们可以表示为卷积!

方程4:离散化状态空间模型的卷积核

核由矩阵C和离散化矩阵A和B构成。请注意,核的长度与输入序列相同。核与输入x卷积以获得输出y。为了使输出的序列长度与输入相同,通常会对输入进行填充。

图 6:离散状态空间模型的卷积表示示例。

这太棒了,因为我们知道卷积是可并行的,并且 GPU 和其他硬件加速器经过优化可以快速计算。 总结本节,利用线性时不变系统的特性,我们最终得到了 3种不同的方式来表示 SSM:

图 7:状态空间模型的不同表示。

SSM 的循环表示和卷积表示是相同的,因此我们可以决定使用其中一个。 在训练期间,我们会选择卷积表示,因为我们可以通过训练数据提前访问整个序列x[k]和标签y[k] 。 在推理过程中,我们会选择循环表示,因为通常我们只能访问过去的样本(例如自回归文本生成)。 请注意,推理与序列长度成线性比例,因为我们的下一个输出仅取决于先前的状态和当前输入。 我们现在有了一个训练高效、推理高效的模型。太棒了!😎 3. 建立长远关系 矩阵A负责从前一个状态创建下一个状态。系统的状态是考虑过去状态和新输入的多次更新的结果。可以说,状态包含输入的整个历史信息。不难想象,对于非常长的序列,准确记住所有输入可能是一项艰巨的任务。 那么,我们如何处理长程依赖关系呢? 3.1 国家规模很重要 我们首先来谈谈状态、状态的压缩和推理的效率。

图 8:RNN、SSM、Transformer 和所需模型的效率与性能

从理论上讲, RNN 和 Transformer 都可以处理永无止境的数据序列(有时人们会说:这些模型是不受约束的)。RNN通过将整个历史记录压缩为单个低维状态来实现这一点。这对于推理非常有效,因为现在下一个状态仅取决于前一个状态和当前输入。RNN 的性能取决于模型学习进行这种压缩的能力。此外,对于长序列,尤其是由于梯度消失,这不是一项容易实现的任务,即使对于使用 GRU 和 LSTM 等门控机制的更高级 RNN 也是如此。另一方面, Transformer根本不压缩历史记录。它通过其注意力机制考虑整个历史记录(以及未来,如果没有像自回归生成那样被掩盖)。这允许并行训练 Transformer,但同时在推理过程中,随着序列长度的增加,内存和计算需求也会激增。 ❓ 那么,再问一遍:我们该如何在高度压缩的状态下处理长距离依赖关系,而不会在训练过程中遇到困难,同时实现高性能? 一种解决方案:我们使用HiPPO(高阶多项式投影运算符)框架初始化我们的矩阵。 3.2 HiPPO 框架 主要思想是从状态方程中参数化矩阵A和B,以便我们可以重建任意给定点的历史(即从时间t=0到t=now 的输入信号)。具体来说,如果状态表示为某些正交多项式(例如勒让德多项式)的某些系数,我们可以简单地评估多项式以重建到当前时间的输入信号。 例如:我们有一个输入函数x(t)。在任何给定的时间步长ti,我们希望找到一个重构函数g(ti) ,它对于ti之前的所有t ,在 t<ti 时近似x(t)。近似值g(ti)由系数c(ti)完整描述,该系数可以插入勒让德多项式以重构g(ti)。

图 9:HiPPO 框架中输入函数 x(t) 的近似值 g(t)。

理解这一点很重要:我们不必只访问输入的压缩表示,而是可以随时重建输入。尽管如此,它还是一个具有线性复杂性所有优点的循环模型。 这些系数会随着时间而变化,因为对于每个ti,我们都有不同的系数c(ti)。那么,我们如何得到这些系数呢?事实证明,c(t)可以定义为一个常微分方程 (ODE):

方程 5:HiPPO 框架系数的 ODE

利用一些数学技巧(例如假设x和g正交)并选择不同时间步长的权重µ(t),可以完全定义此问题。我们可以求解 ODE 并获得满足我们权重选择的A(t)和B(t),并将其用于我们的 SSM。 针对过去样本的不同权重和不同的多项式作为基础,提供了不同的 HiPPO 矩阵作为解决方案。

  • LegT:带滑动窗口的勒让德多项式

  • LegS:具有常数权重的勒让德多项式。

  • LagT:具有指数衰减的拉盖尔多项式

LegS 是后续研究中使用最突出的一个。以下是矩阵A的解:

等式 6:N=4 的 HiPPO 矩阵定义和示例。

一个简单的 Python 实现可能看起来像这样:

 
 

def get_hippo_A ( N:int ) -> np.ndarray: A = np.zeros((N,N)) 对于范围(N)内的 n :对于范围(N)内的k :如果n > k: A[n,k] = (2 *n+ 1)** 0.5 *(2 *k+ 1)** 0.5 elif n == k: A[n,k] = (n+ 1 )返回A

前面提到的 LSSL 确实表明,如果使用 HiPPO 矩阵初始化,则可以训练 SSM,并且确实表明其性能甚至超过了所示实验中的 Transformer。然而,其昂贵的计算和内存要求使得 LSSL 无法作为通用序列模型。直到引入 S4(结构化状态空间序列)模型,这个问题才得到解决,我们将在下一章中讨论这个问题。 4. 为提高效率对状态矩阵施加结构 现在是时候讨论结构化状态空间模型中的结构到底是什么了。我们首先尝试获得一些非常基本的直觉,即结构是什么以及它为什么重要,然后我们将仔细研究 S4 模型。 4.1 什么是结构化矩阵以及它们为什么重要? 让我们考虑一个由NxN 个元素组成的标准方阵。我们将重点关注两个方面:

  1. 我们需要多少内存来表示这个矩阵,

  2. 我们能多快地将两个矩阵相乘,其中快速可以由多种因素决定,例如能够并行化计算,利用冗余,以及首先要处理的值较少。

我们需要多少内存? 让我们从存储矩阵的内存需求开始。我们可以有不同类型的矩阵。有时它们已经是某种形式,有时我们可以利用线性代数的规则将它们变成某种形式。让我们看看一些常见的矩阵类型:

图 10:几种常见矩阵类型存储的值的数量。

我们发现,对于密集矩阵,在最坏的情况下,我们需要将N² 个值存储在内存中。不仅如此,我们还必须将N² 个值从 CPU 加载到 GPU,然后在那里应用某些矩阵运算。 请注意,我们需要在 CPU 端的内存中存储多少值(即在 RAM 或磁盘上)与我们在 GPU 中占用多少内存进行计算是有区别的。为了清楚地说明这一点,让我们看一下低秩分解。一些矩阵可以分解为 2 个NxR的低秩矩阵,其中R是这些矩阵的秩。现在我们只需要为两个矩阵中的每一个存储NR值,因此有2NR值。这也意味着我们只需要从 CPU 加载2NR值而不是N² 值到 GPU。但如果我们想用原始矩阵进行计算,我们首先需要在 GPU 内部重建原始矩阵(有人说将它们具体化),这意味着我们仍然需要GPU 缓存中的N²值。 另一方面,对角线或稀疏矩阵等其他表示形式会减少:存储在磁盘上的值和在 GPU 上执行矩阵运算所需的值的数量。 需要多少次计算? 现在让我们考虑一个标准矩阵乘法C=AB,其中A和B是密集矩阵或对角矩阵,而C是矩阵乘法的结果:

图 11:不同矩阵类型的矩阵乘法运算次数

对于两者都是密集矩阵的情况,使用我们都知道的矩阵乘法的标准算法,我们需要N³ 次乘法和(N-1)N² 次加法。 如果A是密集矩阵而B是对角矩阵,则这将减少到N² 次乘法和0 次加法。 如果A和B都是对角线,我们只需要N 次乘法,即对角线元素的逐元素乘法。虽然这在硬件资源方面是可取的,但请注意,对角矩阵作为密集矩阵的表现力可能较差,因为它的值较少。像所有事物一样,这是一种权衡。 计算的并行化如果你仔细观察密集矩阵乘法,你可能会注意到,计算输出C 的一行涉及B的所有值,但只涉及A的一行。简单地说,你可以在 GPU 的不同核心上并行计算C的每一行来加快计算速度。这意味着每个核心只需要N²次乘法和(N-1)N 次加法,而不是像单核情况下那样需要N³ 次乘法和(N-1)N² 次加法,因为我们现在有一个向量矩阵乘积,而不是矩阵矩阵乘积。 长话短说,总而言之,使用NxN 方阵的示例,通过尝试利用矩阵的某些属性,我们希望将它们带入一种结构化形式,使得我们需要存储和加载少于N² 的值,并且对于矩阵乘法,我们需要少于N³ 次乘法和(N-1)N² 次加法。 当然,我们还可以采用许多其他矩阵结构,这些结构各有优缺点。在本系列关于 Mamba 状态空间模型的文章中,我们将遇到许多这样的矩阵结构,它们都具有相同的目的:在保持最佳性能的同时减少资源消耗。 4.2 综合起来——S4:结构化状态空间序列模型 S4 旨在通过在相关矩阵上强制执行某种结构来提高 SSM 的效率。通过这样做,它可以应用经过充分研究的代数规则来降低所涉及计算的复杂性并节省内存,尤其是对于卷积视图的核K的重复计算。

图 12:S4 结构化状态序列模型。

与 LSSL 相比,S4 的速度提高了 30 倍,内存使用量减少了 400 倍,同时性能也有所提升。 S4 论文涉及大量高级线性代数,我将在这篇博文中基本跳过。如果你想更深入地研究数学和代码实现,我建议你查看《带注释的 S4》。 回想一下,计算SSM 卷积表示的卷积核K时,需要对离散化的方阵A进行重复的矩阵乘法。请记住,核的长度与输入序列x相同,并且矩阵A的幂随着核的每次输入而增加。

方程6:离散状态空间模型的卷积核。

要清楚的是,长度为L=101的序列将具有100 次方的矩阵!而99和98和……你明白了。当然,我们可以重复使用之前的结果,但这仍然需要大量计算。 S4 的主要目标是通过对矩阵A施加一定的结构来减少计算核K所需的资源数量。 理想情况下,矩阵A应该是一个对角矩阵,因为正如我们在图 11 中看到的那样,两个对角矩阵相乘会产生另一个对角矩阵,只需要N 次乘法。 由于 HiPPO 矩阵无法以数值稳定的方式对角化(如 S4 论文中证明的那样),因此下一个最佳方法是将其分解为正态加低秩 (NPLR) 形式,然后可以通过共轭进一步简化为对角加低秩 (DPLR) 形式。 请注意,我们希望保留 HiPPO 矩阵,不要将其更改为另一个可对角化的矩阵,因为我们需要它来处理长程依赖关系。

图 13:比较矩阵 A 的不同表示形式以及计算卷积核 K 的复杂度。

矩阵A为 DPLR 形式允许我们应用许多复杂的技巧。总之,核是在频域中计算的,而 DPLR 形式允许将矩阵重新排列为柯西矩阵,为此存在许多硬件优化算法。 此外,利用 Woodburry 恒等式可以将A的多个幂的计算转化为一个逆运算。最后,应用逆 FFT(快速傅里叶变换)来获得最终的卷积滤波器K。 因此,SSM(A,B,C)以前是A、B和C的函数,现在是函数SSM( Λ- PP*,B,C),其中Λ、 P、B和C是 S4 的可学习参数。 如果您对原始实现感兴趣,这里是S4 存储库的dplr()函数和nplr()函数。 5. 总结 5.1 总结

  • 状态空间模型 (SSM) 是序列模型。

  • SSM 可以描述为循环模型或卷积模型

  • 由于它们仅依赖于当前输入和先前状态,因此可以在其线性扩展的循环表示中快速推断它们。

  • 由于它们的卷积表示可以并行化,因此训练速度很快。

  • 由于它们是 LTI 系统,因此任何时间步骤的每个输入都被平等对待,即由同一个矩阵进行转换。

  • 结构化状态空间序列模型(又名 S4)在其矩阵上使用某种结构来提高 SSM 的速度,从而减少对计算和内存要求的限制。

  • 感谢关注雲闪世界。(Aws解决方案架构师vs开发人员&GCP解决方案架构师vs开发人员)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值