[读论文]Transformers are SSMs

Notation

T T T: Sequence length/ time length
$$:

摘要

虽然transformer一直是深度学习在语言建模方面成功的主要架构,但状态空间模型(ssm),如Mamba,最近被证明在中小规模上与transformer相匹配或优于transformer。这些模型族实际上是非常密切相关的,并在ssm和注意力变体之间发展了一个丰富的理论联系框架,通过对一类经过充分研究的结构化半可分矩阵的各种分解连接起来。状态空间对偶(SSD)框架使我们能够设计一个新的架构(Mamba-2),其核心层是Mamba选择性SSM的改进,速度快了2-8倍,同时在语言建模方面继续与transformer竞争。

介绍

Transformer存在效率问题。例如在训练期间按序列长度进行二次缩放,以及在自回归生成期间需要按序列长度线性的缓存大小。而一类可供选择的序列模型,即结构化状态空间模型(SSMs),在训练期间序列长度呈线性扩展,在生成期间状态大小恒定。同时后者在长程任务上表现出强大的性能,最近在小到中等规模的语言建模上与transformer不相上下或击败了transformer。然而,ssm的发展似乎与社区改进transformer的集体努力无关,例如从理论上理解它们,以及在现代硬件上优化它们。因此,与transformer相比,理解和实验ssm更加困难,从算法和系统的角度来看,训练ssm像transformer一样有效仍然具有挑战性。

本文的主要目标是在结构化ssm和注意力变体之间建立丰富的理论联系。这将使我们能够将最初为transformer开发的算法和系统优化转移到ssm,以建立性能比transformer更好的基础模型,同时在序列长度上更有效地扩展。该方向的一个里程碑式贡献是线性注意力(LA)框架,通过展示二次核注意力的"对偶形式"和特定线性递归之间的等价性,导出了自回归注意力和线性rnn之间的联系。这种对偶性允许新的功能,如具有高效的可并行化训练和高效的自回归推理的能力。同样,本文提供了将线性复杂度ssm与二次复杂度形式联系起来的多个观点,以结合ssm和attention的优势。

状态空间对偶性。所提出的连接结构化ssm和注意力变体的框架,称为结构化状态空间对偶(SSD),是通过对结构化矩阵的抽象来建立的(结构化矩阵是具有次二次参数和乘法复杂度的矩阵)。本文开发了两个广泛的框架来表示序列模型,一个作为矩阵变换,一个作为张量收缩,它们都揭示了对偶性的不同角度。我们的技术贡献包括:

  • 展示了状态空间模型和被广泛研究的结构化矩阵族(半可分离矩阵)之间的等价性(第3节)。这种联系是所提出框架的核心,揭示了ssm的新属性和算法。本文的中心思想是计算状态空间模型的不同方法可以被重新定义为结构化矩阵上的各种矩阵乘法算法
  • 显著改进了线性注意力理论。本文首先通过张量收缩的语言提供了其递归形式的简明证明,然后将其推广到一个新的结构化掩码注意力(SMA)族(第4节)。
  • 连接了SSM和SMA,表明它们有一个彼此对偶的大交集,具有类似SSM的线性和类似注意力的二次形式(第5节)。还证明了任何具有快速递归形式的核注意力方法都必须是SSM。

除了其内在的理论价值外,该框架为理解和改进序列模型开辟了一套广泛的方向。

高效的算法。首先,该框架暴露了新的高效和易于实现的算法,用于计算SSM(第6节)。提出了一种新的SSD算法,基于半可分矩阵的分块分解,同时利用了线性SSM递归和二次对偶形式,在所有主要效率轴上获得了最佳折衷(例如训练和推理计算、内存使用,以及在现代硬件上利用矩阵乘法单元的能力)。SSD的专用实现比Mamba的优化选择性扫描实现快2 - 8倍,同时允许更大的循环状态大小(8倍或更高的Mamba大小,以最小的减慢)。SSD与softmax注意力(FlashAttention-2 (Dao 2024))的优化实现相比具有很强的竞争力,在序列长度为2K时交叉速度提高了6倍,在序列长度为16K时交叉速度提高了6倍。

架构设计。采用ssm等新架构的一个主要障碍是为transformer量身定制的生态系统,如用于大规模训练的硬件高效优化和并行技术。所提出框架允许使用既定的注意力惯例和技术为ssm建立架构设计选择的词汇表,并进一步改进(第7节)。例如,引入了从多头注意力(MHA)到ssm的头的模拟。我们表明这样的Mamba架构是一个多输入SSM (MIS),被证明类似于多值注意力(MVA),并比较了具有不同头部结构的Mamba的其他变体。

我们还使用这些想法对Mamba块进行轻微修改,它允许实现张量并行。主要思想包括引入组值注意力(GVA)头部结构,并将所有数据依赖的投影移动到区块开始时并行发生。

将改进后的并行Mamba块与使用SSD作为内层SSM层相结合,得到Mamba-2架构。本文研究了Mamba-2和Mamba在相同环境下的Chinchilla缩放定律,发现在困惑度和壁钟时间上,它帕累托优于Mamba和Transformer++。在Pile上训练了一系列不同大小的Mamba-2模型,表明它在标准的下游评估中匹配或优于Mamba和开源transformer。例如,在桩上300B token上训练的2.7B参数的Mamba-2优于在相同数据集上训练的Mamba-2.8B, Pythia-2.8B,甚至Pythia-6.9B。

系统优化。SSD框架将ssm和transformer连接起来,使我们能够利用为transformer开发的大量系统优化工作(第8节)。

  • 例如,张量并行(TP)是一种重要的模型并行技术,通过在同一节点上的gpu上划分每一层来训练大型Transformer模型。我们将Mamba-2设计为TP友好型,将每个块的同步点数量减少了一半。
  • 对于非常长的序列,其激活不适合在一个设备上,针对注意力块开发了序列并行性。本文描述了如何通过在设备之间传递循环状态,用序列并行来训练ssm,特别是Mamba-2。
  • 为了对不同长度的示例进行微调,为了获得最佳效率,Transformer需要复杂的技术来删除填充标记并对可变长度序列进行关注。我们展示了如何用可变序列长度有效地训练Mamba-2,不需要填充标记。

第9节对Mamba-2在语言建模、训练效率和一个困难的多查询关联回忆任务上进行了实证验证(Arora, Eyuboglu, Zhang等人2024)。最后,在第10节中,提供了相关的扩展工作,并讨论了该框架开辟的潜在研究方向。
代码见https://github.com/state-spaces/mamba。

2背景及概述

2.1结构化状态空间模型(Structured SSM)

结构化状态空间序列模型(S4)是最近一类用于深度学习的序列模型,与rnn、cnn和经典状态空间模型有广泛的关系。它们的灵感来自于一个特定的连续系统(1),该系统通过隐潜状态ℎ∈R(T,N)映射一个一维序列𝑥∈RT↦→𝑦∈RT。
结构化ssm的一般离散形式采用方程形式(1)
在这里插入图片描述
结构化ssm之所以如此命名,是因为控制时间动态的𝐴矩阵必须是结构化的,以便足够有效地计算这种序列到序列的转换,以便在深度神经网络中使用。最初引入的结构是对角线加低秩(DPLR) 和对角线,这仍然是最受欢迎的结构。

本文使用状态空间模型(SSM)一词来指代结构化SSM。这种ssm有许多风格,与神经序列模型的几种主要范式有很深的联系,如连续时间、递归和卷积模型。我们在下面提供了一个简要的概述,并参考之前的工作以了解更多的背景和细节。

连续时间模型。原始结构化ssm起源于函数𝑥(𝑡)∈R↦→𝑦(𝑡)∈R上的连续时间映射,而不是直接作用于序列。从连续时间的角度来看,在方程(1a)中,矩阵(𝐴,𝐵)不是直接学习的,而是从基础参数(̊𝐴,̊𝐵)以及参数化步长Δ中生成的。“连续参数”(Δ̊𝐴,̊𝐵)转换为“离散参数”(𝐴𝐵)通过固定公式𝐴=𝑓𝐴(Δ,̊𝐴)和𝐵=𝑓𝐵(Δ,̊𝐵),在两人(𝑓𝐴,𝑓𝐵)被称为离散化规则

SSM最初是个连续时间模型,但是可以通过离散化规则,例如零阶保持规则转化为序列模型。

备注1。虽然我们的主要模型采用了与之前工作相同的参数化和离散化步骤,但为了简化阐述和符号,我们在本文的其余部分中省略了它。我们注意到,之前关于结构化ssm的工作将连续参数(̊𝐴,̊𝐵)和离散参数(𝐴,𝐵)称为(𝐴,𝐵)和(̄𝐴,̄𝐵);我们更改了符号以简化表示,并直接关注支配主要SSM递归的离散参数。’

直接专注于离散化参数 A ˉ \bar{A} Aˉ,并且记为 A A A

递归模型。方程(1)和(2)采用递归形式,其输入为线性𝑥。因此,结构化ssm可以被视为递归神经网络(rnn)的类型,其中的线性赋予了它们额外的属性,并允许它们避免传统rnn的顺序计算。反过来,尽管进行了这种简化,ssm仍然可以完全表达为序列转换(在普遍近似的意义上)。

卷积模型。当SSM的动力学随时间不变,如方程(1)所示,该模型称为线性时不变(LTI)。在这种情况下,它们等同于卷积。因此,SSM也可以被视为cnn的类型,但其中(i)卷积核通过SSM参数(𝐴,𝐵,𝐶)隐式参数化,以及(ii)卷积核通常是全局的而不是局部的。相反,通过经典的信号处理理论,所有表现良好的卷积都可以表示为ssm。

如果(1)中参数时不变,则等同于卷积。

通常,以前的LTI ssm将使用卷积模式进行高效的可并行化训练(其中整个输入序列提前看到),并切换到递归模式(1)以进行高效的自回归推理(其中输入一次只看到一步)。

选择性状态空间模型。在Mamba中,参数(𝐴,𝐵,𝐶)也可以随时间变化的形式(2)被引入作为选择性SSM。与标准的LTI公式(1)相比,该模型可以在每个时间步上选择性地选择关注或忽略输入。它被证明在信息密集的数据(如语言)上的表现比LTI ssm好得多,特别是当它的状态大小N增加时,允许更多的信息容量。然而,它只能以递归模式而不是卷积模式进行计算,并且需要仔细的硬件感知实现来提高效率。尽管如此,它仍然低于硬件友好的模型(如cnn和transformer),因为它没有利用矩阵乘法单元,而现代加速器(如gpu和tpu)专门用于矩阵乘法单元。

选择性SSM与LTI SMM相比,可以在每个时间步上选择性地选择关注或忽略输入。效果更好,但是只能以递归模式而不是卷积模式进行计算。尽管可以用硬件感知实现来提高效率,但由于不能利用gpu和tpu进行矩阵乘法加速,所以效率没有cnn和transformer高。

结构化ssm作为序列转换

在这里插入图片描述
序列转换(如ssm或自注意力)是深度序列模型的基石,它们被纳入神经网络架构(如transformer)。(1)或(2)中的SSM是一个P = 1的序列变换;它可以推广为 P > 1 P\gt1 P>1,通过简单地跨此维度广播(换句话说,将输入视为P个独立序列,并对每个序列应用SSM)。我们可以把P看作头部维度,我们将在第7节详细讨论。

序列转换的定义
P P P作为SSM的头维度,等于是有 P P P个独立SSM。

在这里插入图片描述
在ssm中,N维是一个称为状态大小或状态维度的自由参数。我们也称它为状态扩展因子,因为它将输入/输出的大小扩展为𝑁的因子,这对这些模型的计算效率有影响。
许多类型的序列变换,如注意力,可以表示为序列维度上的单个矩阵乘法。

SSM操作的定义

在这里插入图片描述

矩阵转换的定义

2.2 Attention

注意力广义上指的是一种计算类型,它为序列中的每一对位置分配分数,使每个元素“关注”其余的位置。到目前为止,注意力最常见和最重要的变体是softmax自注意力,它可以定义为
Y = softmax ⁡ ( Q K ⊤ ) ⋅ V Y=\operatorname{softmax}\left(Q K^{\top}\right) \cdot V Y=softmax(QK)V
两两比较的机制(由物化𝑄𝐾⊤引起)导致了特征的二次注意力训练成本。
已经提出了许多注意力的变体,但所有变体都共享这些注意力分数的基本核心,具有各种近似。这项工作最重要的变体是线性注意力(Katharopoulos et al. 2020)。粗略地说,这类方法通过将softmax折叠到内核特征映射中来丢弃它,并使用矩阵乘法的结合性来重写(𝑄𝐾⊤)·𝑉=𝑄·(𝐾⊤𝑉)。此外,在因果(自回归)注意的重要情况下,他们表明,当因果掩模合并到左手边作为(𝐿◦𝑄𝐾⊤)·𝑉,其中𝐿是下三角1的矩阵,那么右手边可以作为递归展开。一些最近的和并发的工作,如RetNet (Y. Sun等人2023)和GateLoop (Katsch 2023),将其加强为更一般的形式𝐿(第10节)。本文提出的结构化掩码注意力的提法将强有力地概括这些想法。

2.3 结构化矩阵

一般矩阵𝑀∈R(T,T)需要T2参数来表示和𝑂(T2)时间来执行矩阵-向量乘法等基本操作。结构化矩阵就是那些
(i)可以通过压缩表示以次二次(理想情况下是线性)参数表示
(ii)通过直接对这种压缩表示进行操作,具有快速算法(最重要的是矩阵乘法)。

结构化矩阵的定义

也许最典型的结构化矩阵族是稀疏和低秩矩阵。然而,还有许多其他矩阵族,如Toeplitz、Cauchy、Vandermonde和butterfly矩阵,它们都被用于机器学习的高效模型。结构化矩阵是高效表示和算法的强大抽象。本文表明,SSM等价于另一类以前未在深度学习中使用的结构化矩阵,并利用这种联系来导出有效的方法和算法

2.4 概述:结构化状态空间对偶性

虽然本文开发了ssm、注意力和结构化矩阵之间更丰富的联系框架,但对主要方法进行了简要总结,该方法实际上是非常完备的,在算法上很简单。

递归(线性)形式:状态空间对偶(SSD)层可以定义为选择性SSM(2)的一个特殊情况,可以应用递归(或并行扫描)的标准计算,它在序列长度上具有线性复杂度。与Mamba中使用的版本相比,SSD有两个小的区别:

  • 将𝐴上的对角线结构进一步简化为标量乘单位阵结构。在这种情况下,每个𝐴𝑡也可以用一个标量来标识。
  • 与Mamba中使用的P = 1相比,我们使用更大的头部尺寸P。通常选择P ={64, 128},这类似于现代transformer的约定。

与原始的选择性SSM相比,这些变化可以被看作是表达能力的轻微下降,而训练效率的显著提高。特别是,我们的新算法将允许在现代加速器上使用矩阵乘法单元。

对偶(二次)形式:SSD的对偶形式是一种与注意力密切相关的二次计算,定义为
( L ∘ Q K ⊤ ) ⋅ V L i j = { a i × ⋯ × a j + 1 i ≥ j 0 i < j \left(L \circ Q K^{\top}\right) \cdot V \quad L_{i j}= \begin{cases}a_i \times \cdots \times a_{j+1} & i \geq j \\ 0 & i<j\end{cases} (LQK)VLij={ai××aj+10iji<j
与标准的softmax attention相比,有两个主要区别

  • 放弃softmax。
  • 注意力矩阵乘以一个额外的掩码矩阵𝐿。

这两种变化都可以被视为解决vanilla attention的问题。例如,最近观察到softmax会导致注意力分数的问题,如“注意力汇聚”现象。更重要的是,掩码矩阵𝐿可以被视为用不同的数据依赖的位置掩码取代transformer的启发式位置嵌入,该掩码控制跨时间传输的信息数量。

更广泛地说,这种形式是线性注意力的结构化掩码注意力泛化的一个实例,在第4节中定义。

矩阵形式与SSD算法。SSD的各种形式通过统一的矩阵表示连接起来,通过显示ssm有一个矩阵变换形式𝑌=𝑀𝑋对于一个矩阵𝑀𝜃∈R(T,T),它依赖于𝜃=(𝐴,𝐵,𝐶)。特别是,SSD的对偶形式等价于矩阵𝑀的朴素(二次时间)乘法,而递归形式是一种特殊的高效(线性时间)算法,它利用了𝑀中的结构。
除此之外,任何𝑀乘法算法都可以应用。我们提出的硬件高效的SSD算法(第6节)是一种新的结构化矩阵乘法方法,涉及𝑀的块分解,它获得了比纯线性或二次形式更好的效率权衡。与一般的选择性ssm (Gu和Dao 2023)相比,它相对简单且易于实现;清单1用几行代码提供了一个完整的实现。
图1提供了本文所述概念之间关系的简单路线图。
在这里插入图片描述

2.5 符号

在本文中,我们更喜欢使用可以映射到代码的精确符号
矩阵和向量
小写表示向量(即具有单个轴的张量)
大写表示矩阵(即具有多个轴的张量)
在这项工作中,我们不加粗矩阵。有时,如果一个矩阵沿一个轴排列或重复(因此也可以看作是一个向量),我们可以使用大写或小写。·表示标量或矩阵乘法,◦表示Hadamard (elementwise)乘法。
索引
在这里插入图片描述
维度
在这里插入图片描述
张量收缩
我们将在很大程度上依赖张量收缩或einsum表示法,既为了清晰,又作为陈述和证明结果的中心工具。我们假设读者熟悉这种表示法,它通常在现代张量库(如numpy)中使用。例如,我们可以使用contract(MN, NK→MK)来表示矩阵-矩阵乘法运算符,在我们的符号中contract(MN, NK→MK)(𝑋,𝑌)(相当于𝑋·𝑌)可以转换为numpy代码。einsum(’ mn, nk→mk ', X, Y)。

符号表
在这里插入图片描述

3.状态空间模型是结构化矩阵

本节探讨了作为序列转换的状态空间模型的不同视角,并概述了这种映射的属性和算法。本节的主要结果是关于状态空间模型和一类称为半可分矩阵的结构矩阵之间的等价性,这意味着新的效率结果(定理3.5和3.7)。

3.1SSM的矩阵转换形式

回想一下,我们对SSM的定义是通过(2)定义的参数化映射。我们的理论框架从简单地将这个变换写成映射向量的矩阵乘法开始…
通过数学归纳法等可以得到
y t = ∑ s = 0 t C t ⊤ A t : s × B s x s y = SSM ⁡ ( A , B , C ) ( x ) = M x M j i : = C j ⊤ A j ⋯ A i + 1 B i (3) \begin{aligned} y_t & =\sum_{s=0}^t C_t^{\top} A_{t: s}^{\times} B_s x_s \\ y & =\operatorname{SSM}(A, B, C)(x)=M x \\ M_{j i} & :=C_j^{\top} A_j \cdots A_{i+1} B_i\end{aligned}\tag{3} ytyMji=s=0tCtAt:s×Bsxs=SSM(A,B,C)(x)=Mx:=CjAjAi+1Bi(3)

3.2半可分矩阵

𝑀式(3)是一类称为半可分矩阵的矩阵的特殊表示。半可分矩阵是一种基本的矩阵结构。首先定义这些矩阵及其性质。
定义3.1:一个(下三角)矩阵𝑀是N-半可分的,如果包含在下三角部分(即对角线上或以下)的每个子矩阵的秩最多为N。我们称N为半可分矩阵的秩。

定义3.1,与之相关的其他形式的“可分离”结构(如拟可分矩阵和其他半可分矩阵的定义)有时被称为结构化秩矩阵(或秩结构矩阵),因为它们的子矩阵具有秩条件。半可分矩阵有许多结构化表示,包括层次半可分(HSS)、顺序半可分(SSS)和Bruhat形式(Pernet和Storjohann 2018)。我们将主要使用SSS形式。

3.2.1 顺序半可分(SSS)表示

定义3.2:一个下三角矩阵𝑀∈R(T,T)如果可以写成这样的形式
M j i = C j ⊤ A j ⋯ A i + 1 B i (4) M_{j i}=C_j^{\top} A_j \cdots A_{i+1} B_i \tag{4} Mji=CjAjAi+1Bi(4)则具有N-顺序半可分(SSS)表示。We define the operator SSS so that 𝑀 = SSS(𝐴0:T, 𝐵0:T, 𝐶0:T).
半可分矩阵的一个基本结果是,它们完全等价于具有SSS表示的矩阵。一个方向可以用一个简单的构造性证明来推导。
引理3.3:表示(4)的N-SSS矩阵𝑀是N-半可分的。
这个引理容易被证明。
方程(5)将广泛用于推导序列模型的快速算法。另一个方向在有关半可分矩阵的文献中得到了很好的证明。
命题3.4。每个n-半可分矩阵都有一个N-SSS表示。

此外,请注意,尽管定义3.2涉及用于表示的𝑂(N2T)参数(特别是存储𝐴矩阵),但它实际上可以压缩为𝑂(NT)参数,这是渐进紧密的(Pernet、Signargout和Villard 2023)。因此,在本文的其余部分,我们将合并结构化矩阵类(定义3.1)和它的特定表示(定义3.2);我们将始终使用这种表示而不是其他候选者。反过来,我们将使用N-SS来指代SSS形式的n-半可分矩阵。

半可分矩阵是一种基本的矩阵结构,具有许多重要的性质。它们与递归密切相关,可以通过多种特征(如定义3.1和3.2)定义,这些特征揭示了它们的不同联系和高效算法。我们会在附录c - 1中提到它们的其他一些属性。

N-半可分矩阵,有很多表示,如HSS表示,SSS表示等。但是都是等价的。作者证明了N-半可分矩阵和SSS的等价性。声明后面就只用SSS表示了。
另外提到SSS定义虽然用到了参数量是 N 2 T N^2T N2T的,但实际上只需要 N T NT NT

备注2。半可分性的概念非常广泛,在文献中出现了许多相似但微妙不同的定义;我们的定义可能与其他约定略有不同。首先,由于本文主要关注因果或自回归设置,我们将半可分性的定义限制在三角情况下;定义3.1更正式的说法可能是(N, 0)-半可分性。有些作者也可能把它称为一种准可分性(Eidelman和Gohberg 1999;Pernet 2016)。请参阅Vandebril等人(2005)的简要综述。

3.2.2 1-半可分矩阵:标量SSM递归

我们将挑出1-SS矩阵的特殊情况。注意,在本例中,𝐶𝑗和𝐵𝑖是标量,可以从SSS表示中提取出来(4)(我们还使用小写来强调在本例中参数是标量)
SSS ⁡ ( a , b , c ) = diag ⁡ ( c ) ⋅ M ⋅ diag ⁡ ( b ) \operatorname{SSS}(a, b, c)=\operatorname{diag}(c) \cdot M \cdot \operatorname{diag}(b) \quad SSS(a,b,c)=diag(c)Mdiag(b) where M j i = a j : i × \quad M_{j i}=a_{j: i}^{\times} Mji=aj:i×.

1-SS,也就是1-SSS,所以C,B都是标量。

由于对角矩阵很容易处理(例如,与对角矩阵相乘与elementwise标量乘法相同),我们可以忽略这些项。因此,我们对1-SS矩阵的基本表示是𝑀𝑗𝑖=𝑎𝑗:𝑖或
M = 1SS ⁡ ( a 0 : T ) : = [ 1 a 1 1 a 2 a 1 a 2 1 ⋮ ⋮ ⋱ ⋱ a T − 1 … a 1 a T − 1 … a 2 … a T − 1 1 ] M=\operatorname{1SS}\left(a_{0: T}\right):=\left[\begin{array}{ccccc}1 & & & & \\ a_1 & 1 & & & \\ a_2 a_1 & a_2 & 1 & & \\ \vdots & \vdots & \ddots & \ddots & \\ a_{T-1} \ldots a_1 & a_{T-1} \ldots a_2 & \ldots & a_{T-1} & 1\end{array}\right] M=1SS(a0:T):= 1a1a2a1aT1a11a2aT1a21aT11

1- ss矩阵的重要性在于它们与标量递归的最小形式等价——状态维为N = 1且没有(𝐵,𝐶)投影的退化SSM的情况。注意到乘法𝑦=𝑀𝑥可以通过递归计算
y t = a t : 0 x 0 + ⋯ + a t : t x t = a t ( a t − 1 : 0 x 0 + ⋯ + a t − 1 : t − 1 x t − 1 ) + a t : t x t = a t y t − 1 + x t . (7) \begin{aligned} y_t & =a_{t: 0} x_0+\cdots+a_{t: t} x_t \\ & =a_t\left(a_{t-1: 0} x_0+\cdots+a_{t-1: t-1} x_{t-1}\right)+a_{t: t} x_t \\ & =a_t y_{t-1}+x_t .\end{aligned}\tag{7} yt=at:0x0++at:txt=at(at1:0x0++at1:t1xt1)+at:txt=atyt1+xt.(7)

因此,我们也将1-SS矩阵的矩阵乘法称为标量SSM递归cumprodsum(累积积和;累积积和累积和运算符的推广)。作为递归的基本形式,与1-SS矩阵相乘是我们主要算法的重要组成部分。

本文强调,本文的中心主题之一是,序列模型上的许多算法可以简化为结构化矩阵乘法算法。1-SS矩阵就是这种联系的一个例子:有许多计算基本标量递归或cumprodsum算子的快速算法,它们都等价于1-SS矩阵的不同结构的因子分解。我们将这些算法的附录B用于1-SS矩阵乘法。

这一部分就是在说,如果SSM对应的矩阵用1-SS矩阵,那么在没有 B B B C C C,y的计算可以用标量递归。

3.3 状态空间模型是半可分矩阵

回想一下,我们对SSM的定义是通过定义2.1的参数化映射。SSM和半可分矩阵之间的联系源于简单地将这个变换写成矩阵乘法映射到向量𝑥↦→𝑦∈RT。
式(3)直接建立了状态空间模型与顺序半可分表示之间的联系,顺序半可分表示在一般情况下等价于半可分矩阵(引理3.3和命题3.4)。

SSM是一个参数化映射->可以写成矩阵形式->对应的矩阵是顺序半可分矩阵->顺序半可分就是半可分矩阵。这里的箭头都是等价箭头

定理3.5。状态空间模型转换𝑦=SSM(𝐴、𝐵𝐶)(𝑥),with状态大小N,是与”矩阵乘法:乘以一个顺序半可分表示的N-半可分矩阵𝑦= SSS(𝐴、𝐵𝐶)·𝑥“相同的。

换句话说,序列变换运算符SSM(定义2.2)与矩阵构造运算符SSS(定义3.2)重合,我们可以互换使用它们(有时简称SS)。此外-命运的转折-结构化状态空间模型和顺序半可分矩阵具有相同的缩写,强调了它们的等价性!我们可以方便地使用这些缩写SSM(状态空间模型或半可分矩阵),SSS(结构化状态空间或顺序半可分),或SS(状态空间或半可分)来无歧义地指代这两个概念。然而,我们通常使用这样的约定:SSM指状态空间模型,SS指半可分,SSS指顺序半可分。

图2说明了状态空间模型作为半可分矩阵的序列变换视角。

在这里插入图片描述
图2:(状态空间模型是半可分离矩阵。)作为序列变换,状态空间模型可以表示为作用于序列维数T的矩阵变换M∈R(T,T),为头部(左)中的每个通道共享相同的矩阵。这个矩阵是一个半可分离的矩阵(右),它是一个秩结构矩阵,其中对角线上和下方的每个子矩阵(蓝色)的秩最多为 N,等于 SSM 的状态维度。

3.4 用结构化矩阵算法计算状态空间模型

定理 3.5 很重要的原因是它将使我们能够将 SSM(和其他序列模型)的有效计算问题简化为结构化矩阵乘法的有效算法。我们简要概述了并将我们的主要新算法推迟到第 6 节,然后在第 4 节和第 5 节中展示了 SSM 与其他序列模型的等价性。
如前所述,半可分矩阵(即秩结构矩阵)是一种经典的结构化矩阵:
(i)它们具有压缩表示,如SSS形式,其中只有𝑂(T)而不是𝑂(T2)参数。
(ii)它们具有直接在压缩表示上操作的快速算法。

此外,参数化和矩阵乘法的开销可以是半可分的。
命题3.6:一个大小为 T T T的N-SS矩阵可以用𝑂(NT)参数表示,并具有时间和空间上都是𝑂(NT)的矩阵-向量乘法。

这个命题已经被Pernet, Signargout和Villard于2023年证明

例如,1-SS矩阵说明了这种联系的本质。矩阵𝑀= 1SS(𝑎)精确定义为T−1个参数𝑎0:T−1 =𝑎1,…,𝑎T−1,并且可以根据标量递归(7)在𝑂(T)时间内计算出来。

3.4.1 线性(递归)模式

在对角结构ssm的情况下,只需利用状态空间模型公式(2)并展开递归,就可以很容易地看到命题3.6 (S4D (Gu, Gupta, et al. 2022))。我们在(8)中提供了正式的张量收缩算法,其中维度S等于T。
Z =  contract  ( S P , S N → S P N ) ( X , B ) H = contract ⁡ ( T S N , S P N → T P N ) ( L , Z ) Y = contract ⁡ ( T N , T P N → T P ) ( C , H ) (8) \begin{aligned} Z & =\text { contract }(\mathrm{SP}, \mathrm{SN} \rightarrow \mathrm{SPN})(X, B) \\ H & =\operatorname{contract}(\mathrm{TSN}, \mathrm{SPN} \rightarrow \mathrm{TPN})(L, Z) \\ Y & =\operatorname{contract}(\mathrm{TN}, \mathrm{TPN} \rightarrow \mathrm{TP})(C, H)\end{aligned}\tag{8} ZHY= contract (SP,SNSPN)(X,B)=contract(TSN,SPNTPN)(L,Z)=contract(TN,TPNTP)(C,H)(8)

Here, 𝐿 ∈ R(T,T) is defined as 1SS(𝐴).该算法包含对应于(2)的3个步骤:
(i)通过输入矩阵𝐵(8a)对输入进行扩展𝑋
(ii)展开独立标量SSM递归(8b),和
(iii)将隐藏状态𝐻缩并到输出矩阵𝐶(8c)。
注意,我们在步骤(8b)中使用了标量ssm和1-SS矩阵之间的等价关系。

备注3。我们注意到(8)是Mamba (S6)模型的一个特殊情况。然而,一个简单的实现是缓慢的,因为大小为(T, P, N)的扩展张量𝑍和𝐻;Gu和Dao(2023)引入了一种硬件实现来避免物化这些张量。

令人惊讶的是,定理3.5和命题3.6立即暗示所有ssm具有与算法(8)相同的渐近效率。

定理3.7。序列长度为T的状态大小为N的任何状态空间模型(定义2.2)都可以在时间𝑂(TN)内计算出来(不考虑潜在的预处理)。

该结果对结构化SSM文献来说是新的。特别是,给定密集的非结构化𝐴𝑡矩阵,单独的总体表示似乎是大小𝑂(TN2)。因此定理3.7陈述了一个非平凡的结果,通过预处理步骤,即使是非结构化的SSM也可以高效地计算,其上界与下界𝑂(TN)匹配,下界由𝐵和𝐶的大小给出。

备注4。鉴于R(N,N)上的几乎所有稠密矩阵在复数域上都可对角化,定理3.7也许并不太令人惊讶,这导致了几乎所有稠密实数SSM等价于对角复数SSM的结果。这一事实解释了为什么对角SSM是结构化SSM最流行的形式。然而,定理3.7意味着所有实ssm(不仅仅是对角化的ssm)以及其他域(包括C本身)上的稠密ssm的更强的结果。

在实践中,有效计算的ssm仍然需要𝐴上的额外结构,特别是为了避免昂贵的预处理步骤(这两个步骤都有N阶额外的FLOPs,并涉及到硬件低效的操作,如奇异值分解)`。这些结构是过去结构化ssm(如S4(D)和Mamba)以及我们的新算法的重点。特别是,当在𝐴上施加稍微更强的结构时,我们将通过第6节中的SSM矩阵𝑀= SSS(𝐴,𝐵,𝐶)的块分解设计非常硬件高效的算法。

序列长度为T的状态大小为N的任何状态空间模型(定义2.2)都可以在时间𝑂(TN)内计算出来。

3.4.2 二次(朴素)模式

我们注意到,还有另一种方法来计算由我们的新矩阵观点暴露的SSM。对矩阵SSM表示(3)的简单计算涉及简单物化序列变换矩阵𝑀= SSS(𝐴,𝐵,𝐶)。这是一个(T, T)矩阵,因此这个简单的算法将按序列长度进行二次扩展。然而,当序列长度T很短时,由于常数因子和计算模式的硬件友好性(例如利用矩阵-矩阵乘法),这实际上可以比线性算法更高效。事实上,对于结构化ssm的特定情况,这看起来非常类似于二次注意力计算(第5节)。

3.4.3 总结

许多序列模型被明确地激发或定义为矩阵序列变换——最著名的是transformer,其中矩阵混合器是注意力矩阵。另一方面,rnn和ssm之前没有以这种方式描述。通过提供状态空间模型的显式矩阵变换形式,揭示了理解和使用它们的新方法。从计算的角度来看,任何计算状态空间模型前向传递的方法都可以看作是半可分矩阵上的矩阵乘法算法。半可分矩阵视角为状态空间对偶(SSD)提供了一个透镜,其中的对偶模式分别指线性时间半可分矩阵乘法和二次时间朴素矩阵乘法

4.结构化掩码注意力:基于结构化矩阵的线性注意力泛化

在本节中,我们将从基本原理重新审视线性注意力框架。本节的主要结果是基于张量收缩的线性注意力的简单证明(命题4.1),以及定义4.2中对结构化掩码注意力的广义抽象。我们注意到,本节从不同于状态空间模型的方向推导出主要的对偶结果,可以完全独立于第3节阅读。

  • 4.1节建立了注意力变体的框架,特别关注内核注意力和掩码内核注意力。
  • 4.2节提供了我们的第一个主要注意力结果,通过张量收缩的视角简单证明了线性注意力。
  • 4.3节定义了结构化掩码注意力,通过结构化矩阵对先验注意力变体的泛化。

4.1 注意力框架

4.1.1 Attention

(单头)注意力的基本形式是三个向量序列(𝑄,𝐾,𝑉)↦→Y
Q =  input  ( T , N ) K =  input  ( S , N ) V =  input  ( S , P ) G = Q K ⊤ ( T , S ) M = f ( G ) ( T , S ) Y = G V (   T , P ) (9) \begin{array}{rlr} Q & =\text { input } & (\mathrm{T}, \mathrm{N}) \\ K & =\text { input } & (\mathrm{S}, \mathrm{N}) \\ V & =\text { input } & (\mathrm{S}, \mathrm{P}) \\ G & =Q K^{\top} & (\mathrm{T}, \mathrm{S}) \\ M & =f(G) & (\mathrm{T}, \mathrm{S}) \\ Y & =G V & (\mathrm{~T}, \mathrm{P}) \end{array}\tag{9} QKVGMY= input = input = input =QK=f(G)=GV(T,N)(S,N)(S,P)(T,S)(T,S)( T,P)(9)
我们使用“形状标注”来表示张量的维度,例如𝑄∈R(T,N)。在一般形式中,S和T表示源序列和目标序列的长度,N表示特征维度,P表示头部维度。
attention最常见的变体使用softmax激活𝑓= softmax来规范化𝐺矩阵的行。

4.1.2 Self-Attention

我们关注最重要的自注意力:

  • S=T
  • 特征和头维度相同,N=P
  • 𝐾𝑄𝑉由相同的输入向量通过线性投影生成(𝑄=𝑊𝑄·𝑋,𝐾=𝑊𝐾·𝑋,𝑉=𝑊𝑉·𝑋)
    然而,我们的演示将这些选择进行了抽象,并从𝑄,𝐾,𝑉矩阵开始
    备注5。我们的重点是头部和特征维度相同的自注意力情况(即S = T和N = P),应该用作运行示例。定义了注意力的一般公式,不仅使框架捕获了交叉注意力等变体,还因为分离维度的符号(如S和T)使本节主要结果的压缩符号证明更加清晰。
    备注6。虽然注意力通常被框架为对这三个输入𝑄,𝐾,𝑉的操作,这三个输入是对称的,但(9)中的输入和输出维度表明了相反的情况。特别是,特征维度N在输出中不存在;因此,在S = T(例如,自注意力)的情况下,我们将𝑉视为主要输入,因此(9)定义一个适当的序列转换𝑉↦→𝑌(定义2.1)

4.1.3 Kernel Attention

在这里插入图片描述

将softmax函数应用于Gram矩阵𝐺的步骤可以分解为两部分:
1.对𝐺矩阵求幂。
2.在S轴上归一化𝐺矩阵。

我们现在可以忽略归一化这一项,因为它相当于简单地传入𝑉= 1并进行除法(我们会在7.3节再次讨论这个问题)。指数项可以看作是一个核变换:有一个(无限维)特征映射𝜑,exp(𝑄𝐾⊤)=𝜑(𝑄)𝜑(𝐾)⊤。通过将特征映射抽象为𝑄和𝐾本身的定义(即将𝑄,𝐾定义为转换后的版本),我们可以忽略softmax转换,并假设𝑄,𝐾是由内核特征映射任意生成的,可能N≠P。
已经提出了许多内核注意力的实例,包括:

  • 原始的线性注意力将核特征图定义为任意逐点激活函数,如𝑥↦→1 + elu(𝑥)。
  • 随机特征注意(RFA) (H. Peng et al. 2021)选择核特征图来近似softmax注意(即exp特征图),使用高斯核的随机傅里叶特征近似(Rahimi和Recht 2007)。这涉及到随机投影(即将𝑄和𝐾乘以随机投影𝑊并应用激活𝑥↦→(cos(𝑥),sin(𝑥))。
  • Performer (Choromanski等人。2021)提出了通过正正交随机特征(FAVOR+)的快速注意力。正随机特征(PRF)部分将核特征映射选择为随机投影,然后是特征映射𝑥↦→2−1/2 (exp(𝑥),exp(−𝑥))。这种选择是受启发的,以便内核元素是正值,并可证明近似于softmax注意力。[它还建议在正交方向上选择随机投影,我们没有考虑。]
  • coformer (Qin, Weixuan Sun, et al. 2022)通过余弦重加权机制增强RFA,该机制包含位置信息以强调局部性。这有效地传递𝑄𝑡,𝐾𝑡通过特征映射𝑥↦→(𝑥cos(𝜋𝑡/ 2𝑇),罪(𝜋𝑡/ 2𝑇))。
  • 线性随机注意力(Zheng、C. Wang和Kong 2022)从重要性采样的角度对RFA进行了推广,并对其进行了推广,以提供对整个softmax核(而不仅仅是exp分子)的更好估计。

其他相关的注意力变体包括Linformer (Sinong Wang et al. 2020)和Nyströformer (Xiong et al. 2021),它们都分别通过随机投影(Johnson-Lindenstrauss)和核近似(Nyström方法)使用注意力矩阵的低秩近似𝑀(因此与公式(9)兼容)。

Gram矩阵是任意k个向量之间两两的内积所组成的矩阵。参见https://zhuanlan.zhihu.com/p/105470826
Gram矩阵的softmax可以拆分为exp,但 Q K T QK^T QKT是Gram矩阵?

4.1.4 掩码(核)注意力

让𝐿是形状为(T, S)的掩码。最常见的是,在S = T的自回归自注意力情况下,𝐿可能是一个由1代表因果掩码的下三角矩阵。除了加强因果关系外,还可以应用许多其他类型的掩模——特别是各种稀疏模式,如带状、扩张或块对角线——这是通过降低密集注意力的复杂性而激发的。
掩码注意力通常用矩阵表示法表示为
y = ( L ∘ ( Q K ⊤ ) ) ⋅ V (10) y=\left(L \circ\left(Q K^{\top}\right)\right) \cdot V\tag{10} y=(L(QK))V(10)
更准确地说,使用形状注释,并将其分解为精确的计算序列:
G = Q K ⊤ ( T , S ) M = G ∘ L (   T ,   S ) Y = M V (   T , P ) (11) \begin{aligned} G & =Q K^{\top} & & (\mathrm{T}, \mathrm{S}) \\ M & =G \circ L & & (\mathrm{~T}, \mathrm{~S}) \\ Y & =M V & & (\mathrm{~T}, \mathrm{P}) \end{aligned}\tag{11} GMY=QK=GL=MV(T,S)( T, S)( T,P)(11)
我们在本节中改进的注意力变体推导,首先注意到这个公式可以写成一个收缩式:
Y =  contract  ( T N , S N , S P , T S → T P ) ( Q , K , V , L ) (12) Y=\text { contract }(\mathrm{TN}, \mathrm{SN}, \mathrm{SP}, \mathrm{TS} \rightarrow \mathrm{TP})(Q, K, V, L)\tag{12} Y= contract (TN,SN,SP,TSTP)(Q,K,V,L)(12)
(11)中的算法可以重构为通过成对收缩的特定顺序计算(12)
G = contract ⁡ ( T N , S N → T S ) ( Q , K ) M =  contract  ( T S , T S → T S ) ( G , L ) Y =  contract  ( T S , S P → T P ) ( M , V ) (13) \begin{aligned} & G=\operatorname{contract}(\mathrm{TN}, \mathrm{SN} \rightarrow \mathrm{TS})(Q, K) \\ & M=\text { contract }(\mathrm{TS}, \mathrm{TS} \rightarrow \mathrm{TS})(G, L) \\ & Y=\text { contract }(\mathrm{TS}, \mathrm{SP} \rightarrow \mathrm{TP})(M, V) \\ & \end{aligned}\tag{13} G=contract(TN,SNTS)(Q,K)M= contract (TS,TSTS)(G,L)Y= contract (TS,SPTP)(M,V)(13)

4.2 线性注意力

线性注意力和许多其他有效注意力的变体,通常是由改变核心注意力计算中的矩阵结合顺序(𝑄𝐾⊤)𝑉=𝑄(𝐾⊤𝑉)所激发的。然而,当添加掩码时,推导有些不那么直接(例如,原始论文(Katharopoulos et al. 2020)和变体(Y. Sun et al. 2023)声明了没有证明的公式)。
粗略地说,线性注意力方法声称下式等价于(10),这必须通过仔细展开和和跟踪指标来验证。
Y = Q ⋅ cumsum ⁡ ( K ⊤ V ) (14) Y=Q \cdot \operatorname{cumsum}\left(K^{\top} V\right)\tag{14} Y=Qcumsum(KV)(14)
命题4.1 (Katharopoulos et al. 2020)。自回归核注意力,即带有因果掩码的掩码核注意力,可以在𝑂(𝑇)时间内计算,通过每步采取常数时间的递归。

这个命题是被Katharopoulos声明的公式,但是论文中并没有被证明。

4.2.1 线性注意力的张量收缩证明

本文提出一个简单而严格的线性注意力推导,也将立即揭示如何对其进行泛化。其主要思想是以另一种顺序执行收缩(12)。我们避免模糊的矩阵表示法,直接使用收缩表示法:
Z =  contract  ( S P , S N → S P N ) ( V , K ) H =  contract(TS, SPN  → TPN ⁡ ) ( L , Z ) Y =  contract  ( T N , T P N → T P ) ( Q , H ) (15) \begin{aligned} & Z=\text { contract }(\mathrm{SP}, \mathrm{SN} \rightarrow \mathrm{SPN})(V, K) \\ & H=\text { contract(TS, SPN } \rightarrow \operatorname{TPN})(L, Z) \\ & Y=\text { contract }(\mathrm{TN}, \mathrm{TPN} \rightarrow \mathrm{TP})(Q, H) \end{aligned}\tag{15} Z= contract (SP,SNSPN)(V,K)H= contract(TS, SPN TPN)(L,Z)Y= contract (TN,TPNTP)(Q,H)(15)
直观地说,我们将这种收缩顺序解释如下。
第一步(15a)执行“扩展”到更多的特征,特征维度n的倍数。第三步(15c)收缩扩展后的特征维度。如果将𝐾视为输入(备注6),则𝑉和𝑄分别执行扩展和收缩

改变顺序,之前掩码线性注意力是Q为输入,K,V执行收缩。

第二步是最关键的,它解释了线性注意的线性部分。首先请注意(15b)只是一个直接与𝐿的矩阵乘法(因为(P, N)轴可以被展平)。还要注意,这是唯一涉及到S轴和S轴的项,因此应该具有Ω(TS)复杂性(即序列长度的二次)。然而,当掩码𝐿是标准的因果注意掩码(下三角1)时,𝐿的矩阵-向量乘法与逐特征累计和等价
y = [ 1 ⋮ ⋱ 1 … 1 ] x ⟺ y 0 = x 0 y t = y t − 1 + x t . y=\left[\begin{array}{ccc} 1 & & \\ \vdots & \ddots & \\ 1 & \ldots & 1 \end{array}\right] x \quad \Longleftrightarrow \quad \begin{aligned} & y_0=x_0 \\ & y_t=y_{t-1}+x_t \end{aligned} . y= 111 xy0=x0yt=yt1+xt.

P,N当它们是固定的,所以是Ω(TS)。
感觉说来说去就是一件事,下三角矩阵等价于递归形式。所以每次递归计算都是常数的,最终整个计算关于序列T是线性的。
类似的,稀疏矩阵当然应该会有优化算法?

4.3 结构化掩码注意力

从掩码注意力的张量收缩视角(15),我们可以立即看到,原始线性注意力的关键是因果掩码的矩阵向量乘法等价于累积和运算符。
然而,我们观察到没有理由注意力掩码必须都是1。线性注意力快速所必需的是𝐿是一个结构化矩阵,根据定义,它具有快速矩阵乘法(2.3节)。特别是,我们可以使用任何掩码矩阵𝐿,它具有次二次(理想情况下是线性)矩阵-向量乘法,通过加速瓶颈方程(15b),它将具有与标准线性注意力相同的复杂性。

这段话说出来,作者其实就默认了前面说的掩码注意力是那种下三角全1的注意力。所以作者在C.2才会说将结构化(掩码)注意力定义为掩码注意力的广泛泛化。
另外,这里想吐槽一下,一会结构化掩码注意力,一会结构化注意力。看起来确实很累。定义前前后后名称一会略写,一会不略写。

定义4.2结构化掩码注意力(SMA)(或简称结构化注意力)被定义为Q,K,V和任何结构化矩阵L(即具有次二次矩阵乘法)上的一个函数,通过四向张量收缩:
Y = contract ⁡ ( T N , S N , S P , T S → T P ) ( Q , K , V , L ) .  Y=\operatorname{contract}(\mathrm{TN}, \mathrm{SN}, \mathrm{SP}, \mathrm{TS} \rightarrow \mathrm{TP})(Q, K, V, L) \text {. } Y=contract(TN,SN,SP,TSTP)(Q,K,V,L)
SMA二次模式算法是由(13)定义的成对收缩序列,对应于标准(掩码)注意力计算
SMA线性模式算法是由(15)定义的成对收缩序列,其中步骤(15b)通过次二次结构化矩阵乘法进行优化
我们可以将结构化掩码注意力实例化到任何给定的矩阵结构类。一些例子包括(图3):
(Structured Masked Attention
图3:(结构化的掩码注意力。)SMA构造了一个掩码注意力矩阵𝑀=𝑄𝐾⊤◦𝐿为任何结构化矩阵𝐿,这定义了矩阵序列变换𝑌=𝑀𝑉。SMA的所有实例都具有对偶次二次形式,这是由不同的收缩顺序和𝐿的高效结构化矩阵乘法引起的。之前的例子包括线性注意力(Katharopoulos et al. 2020)和RetNet (Y. Sun et al. 2023)。除了本文的重点SSD (1-semiseparable SMA)之外,还有许多其他结构化注意力的潜在实例是可能的。

  • 线性注意力使用因果掩模。
  • RetNet (Y. Sun et al. 2023)对某些衰减因子𝛾∈[0,1]使用衰减掩膜𝐿𝑖𝑗=𝛾𝑖−𝑗·I[𝑗≥𝑖]
  • 衰减掩码可以泛化为Toeplitz矩阵𝐿𝑖𝑗=𝛼𝑖−𝑗用于一些可学习(或依赖输入)的参数集𝛼∈rt。这可以被解释为一种相对位置编码的形式,使人想起其他方法,如不证明(Press, N. Smith和Lewis 2022),但具有乘法性而不是加法性。
  • 另一种变体可以使用傅里叶矩阵𝐿𝑖𝑗=𝜔𝑖𝑗/T以不同的方式编码位置结构。

在第5节中,我们将考虑半可分离的SMA,它定义了我们的主要SSD模型。

4.3.1 总结:掩码注意力的双重形式

标准(掩码核)注意力经常在函数和算法之间混淆。将这种区别分开,为理解不同的注意力变体提供了一种清晰的方法。

  • 我们将被掩码注意力视为一个特定的函数(12)。
  • 标准的二次注意力计算(13)可以看作是计算该函数的算法。
  • 线性注意力(15)是计算该函数的另一种算法。

此外,在这种情况下

  • 掩码注意力函数只是四项的特定收缩。
  • 二次注意力和线性注意力算法只是执行收缩的两种不同顺序。
    众所周知,收缩顺序会在计算复杂度上产生很大的差异,从而导致二次分裂和线性分裂。就像状态空间模型是一种转换,可以以多种方式计算,有对偶的二次形式和线性形式(第3.4节),线性注意力也有类似的对偶性,由两个收缩顺序产生。事实上,这些都是对同一个潜在的二元性的不同观点,我们会在第5节中详细说明。

5.状态空间对偶

在第3节和第4节中,我们定义了结构化状态空间模型和结构化注意力,讨论了它们的属性,并表明它们都有一个二次算法和一个线性算法。本节将它们联系起来。主要结果表明,结构化状态空间模型的特定情况与结构化注意力的特定情况相吻合,线性时间SSM算法和二次时间核注意力算法是彼此的对偶形式。

  • 5.1节专门介绍标量结构的状态空间模型,其中简单的二次计算可以看作是核注意力的一个实例。
  • 第5.2节专门研究半可分SMA的结构化掩码注意力,其特征是具有高效的自回归的掩码注意力。
  • 第5.3节总结了结构化掩码注意力和结构化状态空间模型之间的联系,称为结构化状态空间对偶。

5.1 标量-单位阵结构状态空间模型

在第3节中,我们证明了状态空间模型等价于半可分离矩阵变换,从而得到线性递归形式和二次朴素形式。

回想一下,一类是由𝑦=SSM(𝐴,𝐵, 𝐶)(𝑥),和矩阵形式的SSM使用SSS(顺序半可分)表示𝑀= SSS(𝐴,𝐵,𝐶)其中𝑀𝑗𝑖=𝐶⊤𝑗𝐴𝑗:𝑖𝐵𝑖(方程(3))。

现在让我们考虑 A j A_j Aj只是一个标量的情况;换句话说,一个结构化SSM的实例,其中𝐴矩阵极端结构化:𝐴=𝑎𝐼用于标量𝑎和单位矩阵𝐼。然后我们可以重新安排
M j i = A j : i ⋅ ( C j ⊤ B i ) M_{j i}=A_{j: i} \cdot\left(C_j^{\top} B_i\right) Mji=Aj:i(CjBi)

单位阵当然满足交换律,所以可以提出来。

这可以被向量化为
L : = 1 S S ( a ) M = L ∘ ( C B ⊤ ) \begin{aligned} L & :=1 \mathrm{SS}(a) \\ M & =L \circ\left(C B^{\top}\right) \end{aligned} LM:=1SS(a)=L(CB)
其中, B , C ∈ R ( T , N ) B, C \in \mathbb{R}^{(\mathrm{T}, \mathrm{N})} B,CR(T,N)

这里是为了凑配成 Q K T QK^T QKT的样子所以写成C和B这样的。
M的实际维度是 T P ∗ T P TP*TP TPTP,所以 C C C B B B的实际维度是 T P ∗ N TP*N TPN,这里是把第一个维度看成个整体。也就是 C C C的元素 C i j C_{ij} Cij是一对列向量。 B B B同理。
我不清楚这是否会影响后面张量收缩的推导。感觉只要不调用交换律,应该问题都不大。(因为向量乘积,即矩阵乘积,也是不能用交换律的)。

使用这个公式,完整的输出𝑌=𝑀𝑋精确地计算为
G = contract ⁡ ( T N , S N → T S ) ( C , B ) ( T , S ) M = contract ⁡ ( T S , T S → T S ) ( G , L ) ( T , S ) Y =  contract  ( T S , S P → T P ) ( M , X ) ( T , P ) \begin{aligned} & G=\operatorname{contract}(\mathrm{TN}, \mathrm{SN} \rightarrow \mathrm{TS})(C, B) \quad(\mathrm{T}, \mathrm{S}) \\ & M=\operatorname{contract}(\mathrm{TS}, \mathrm{TS} \rightarrow \mathrm{TS})(G, L) \quad(\mathrm{T}, \mathrm{S}) \\ & Y=\text { contract }(\mathrm{TS}, \mathrm{SP} \rightarrow \mathrm{TP})(M, X) \quad(\mathrm{T}, \mathrm{P}) \\ & \end{aligned} G=contract(TN,SNTS)(C,B)(T,S)M=contract(TS,TSTS)(G,L)(T,S)Y= contract (TS,SPTP)(M,X)(T,P)
其中S = t。但这与掩码内核注意力定义(13)的原始定义完全相同!
因此,如3.4节所述,通过实例化半可分矩阵𝑀并执行二次矩阵-向量乘法来简单地计算标量结构ssm与二次掩码核注意力完全相同。

标量SSM->使用SSS表示->标量单位阵提出来->写成张量收缩的形式->发现完全等价于”掩码核注意力“的张量收缩形式。只是那边是Q,K,V,这里对应是C,B,X。

5.2 1-半可分结构化掩码注意力

结构化掩码注意允许使用任何结构化掩码𝐿。当𝐿是因果掩模时,它是标准的线性注意。请注意,因果掩码是𝐿= SS(1𝑇),即1-SS掩码由定义(6)中的𝑎𝑡= 1生成。这激励了将𝐿推广到1-半可分掩码类,或1-半可分结构化掩码注意力(1-SS SMA),其中线性注意递归中的累加被更一般的递归-标量SSM扫描取代,即1-半可分矩阵乘法(第3.2.2节)。
M = SSS ⁡ ( a 0 : T ) : = [ 1 a 1 1 a 2 a 1 a 2 1 ⋮ ⋮ ⋱ ⋱ a T − 1 … a 1 a T − 1 … a 2 … a T − 1 1 ] (6) M=\operatorname{SSS}\left(a_{0: T}\right):=\left[\begin{array}{ccccc} 1 & & & & \\ a_1 & 1 & & & \\ a_2 a_1 & a_2 & 1 & & \\ \vdots & \vdots & \ddots & \ddots & \\ a_{T-1} \ldots a_1 & a_{T-1} \ldots a_2 & \ldots & a_{T-1} & 1 \end{array}\right] \tag{6} M=SSS(a0:T):= 1a1a2a1aT1a11a2aT1a21aT11 (6)

上面这张图是因果掩码
这里的意思就是说,
因果掩码是是1-SS的特例。
因果掩码对应线性注意力。1-SS掩码对应1-SS SMA。
因果掩码的算法是线性注意递归。1-SS SMA的算法是递归-标量SSM扫描(3.2.2)。

最后,考虑1-半可分SMA的最重要的原因是计算它的线性形式是对角状态空间模型的一种特殊情况。SMA的线性形式是算法(15),其中瓶颈步骤(15b)可以看作是与1-SS掩码的矩阵乘法。在第3节中,我们还写出了对角线SSM(8)的计算,其中瓶颈步骤(8b)是标量SSM递归,等价于1-SS矩阵乘法。唯一的区别是(8b)在𝐿中有一个额外的N维,因为矩阵𝐴是一个大小为N的对角线矩阵。如果𝐴的所有对角线元素都相同,那么这N维将消失,这导致了结论5.1。
在这里插入图片描述
在这里插入图片描述
推论5.1。1-SS SMA(带1-半可分结构化矩阵 L L L的掩码注意力)(15)是对角SSM(8)的特例,其中对角矩阵是标量乘以单位阵。

虽然推论5.1说1-SS SMA具有有效的递归形式,但我们也可以显示一个相反的结果,以表征哪些SMA实例具有有效的自回归。

对角矩阵的时候,L矩阵有N维,所以矩阵相乘,L相当于N片,每个都在做对应乘法。而标量-单位阵的时候,L矩阵没有N了,实际上也是对应相乘,但是L是唯一了,所以L要和Z对应相乘,L仍要乘以N遍。相当是广播了。

定理5.2。有界阶的自回归过程的结构化掩码注意力的实例化(定义4.2),其结构化掩码𝐿必定是一个半可分矩阵。
换句话说,有效的自回归注意力是通常的半可分SMA。定理5.2的证明见附录C.2。
在这里插入图片描述

不是无限阶的自回归过程,是有效的。而有效的自回归注意力,是半可分SMA。

备注7。1-半可分SMA是一种特殊的状态空间模型,一般的半可分SMA比1-SS SMA具有更强的表达能力,不能用标准的SSM描述。然而,𝐿的半可分乘法和SMA的线性形式(方程(15a))都涉及一个扩张和收缩步骤,并且可以被吸收为具有单一(更大)扩张的1-SS SMA的类似实例。

综上所述,1-半可分结构注意力是SMA最重要的情况,因为它是:
1.具有输入依赖递归的线性注意力的自然泛化。
2.一般半可分注意力的最简单情况,等价于高效的自回归注意力
3.对角状态空间模型的一种特殊情况。

5.3 结构化状态空间对偶性(SSD)

总结一下我们的结果:

  • 结构化状态空间模型(第3节)是一种通常通过线性时间递归定义的模型。然而,通过扩展矩阵形式表征其线性”序列到序列“变换,可以导出一个二次形式。
  • 注意力变体(第4节)是通过二次时间成对交互定义的模型。然而,通过将其视为四向张量收缩并按不同的顺序还原,可以推导出线性形式。
  • 每一个的自然特例-更准确地说,𝐴矩阵上具有标量-单位阵结构的状态空间模型,以及𝐿掩码上具有1-半可分结构的结构化掩码注意力-是彼此具有完全相同的线性和二次形式的对偶。

图4总结了这两种表示之间的对偶性。
一个扩展的相关工作和讨论(第10节)更详细地描述了SSD和一般SSMs / attention之间的关系。
在这里插入图片描述
图4:(结构化状态空间对偶性。)状态空间对偶描述了状态空间模型与掩码注意力之间的密切关系。(左)一般的ssm和SMA都具有线性和二次形式,在符号上有直接的类比。(右)ssm和SMA在一大类状态空间对偶模型(SSD)中相交,这些模型捕获了许多作为特殊情况的序列模型。

6. 一种面向SSD模型的硬件高效算法

开发ssm、注意力和结构化矩阵之间的SSD理论框架的好处在于使用连接来改进模型和算法。在本节中,我们将展示如何从计算结构化矩阵乘法的各种算法中推导出高效计算SSD模型的各种算法。

我们的主要计算结果是一个计算SSD模型的算法,该算法结合了线性(递归)模式和二次(注意力)模式。该算法与SSMs(序列长度线性缩放)一样高效,与attention(主要使用矩阵乘法)一样硬件友好。

定理6.1。考虑一个状态扩展因子 N N N和头维度 P = N P=N P=N的SSD模型。存在一种算法,可以在任意输入 X ∈ R ( T , P ) X\in R^{(T,P)} XR(T,P)上计算模型,只需要 O ( T N 2 ) O(TN^2) O(TN2)的训练FLOPs, O ( T N ) O(TN) O(TN)的推理FLOPs, O ( N 2 ) O(N^2) O(N2)的推理内存,其工作以矩阵乘法为主。

请注意,所有这些界限都是紧的。因为在头大小为N的情况下,状态扩展为N的状态空间模型中,总状态大小为 N 2 N^2 N2(分别得到训练和推理的下界为 O ( T N 2 ) O(TN^2) O(TN2) O ( N 2 ) O(N^2) O(N2))。此外,输入𝑋本身具有TN个元素,从而产生内存下界。

定理6.1背后的主要思想是再次将计算状态空间模型的问题视为半可分离矩阵乘法,但以一种新的方式利用其结构。我们对矩阵进行分块分解,而不是以循环模式或注意力模式计算整个矩阵。对角线块可以使用对偶注意力模式来计算,这可以通过矩阵乘法高效地完成,而非对角线块可以通过半可分矩阵的秩结构来分解,并简化为一个较小的递归。我们要强调的是,清单1提供了一个独立的SSD算法实现。与Gu和Dao(2023)的一般选择性SSM相比,这种实现要简单得多,即使在本地PyTorch中也相对高效,不需要特殊的底层内核。

首先,我们将矩阵𝑀划分为 T Q \frac{T}{Q} QT × T Q \frac{T}{Q} QT网格,其中是大小为Q × Q的子矩阵,对于某些大小为Q的块。注意,通过半可分矩阵的定义属性,非对角线块是低秩的(定义3.1)

举个例子:T=9,Q=3
阴影部分是半可分矩阵非对角线块的低秩分解。
在这里插入图片描述
从这里我们可以将问题归结为这两部分。这些也可以解释为将“块”𝑦𝑗Q:(𝑗+1)Q的输出分成两个部分:块内输入的效果𝑥𝑗Q:(𝑗+1)Q,以及块前输入的效果𝑥0:𝑗Q

6.1 对角块

对角线上的块很容易处理,因为它们只是尺寸较小的自相似问题。𝑗-th块代表计算答案SSM(𝐴𝑅,𝐵𝑅,𝐶𝑅)(𝑥𝑅)𝑅=𝑗Q范围:Q(𝑗+ 1)=(𝑗Q𝑗Q + 1。,𝑗Q + Q−1)。关键是这个块可以使用任何所需的方法进行计算。特别地,对于长度较小的块Q,使用对偶二次SMA形式可以更有效地计算该问题。此外,这些块可以并行计算。

这些子问题可以解释为:假设(块)的初始状态为0,每个块的输出是什么。换句话说,对于块𝑗,这只考虑块输入𝑥𝑗Q:(𝑗+1)Q来计算正确的输出。

这些块只作用在对应的输入块上,并得到对应的输出块

6.2 低秩块

低秩因子分解包含3项,相应的计算分为3部分。在这个因子分解中,我们将使用术语

  • [ B 0 ⊤ A 2 : 0 B 1 ⊤ A 2 : 1 B 2 ⊤ A 2 : 2 ] ⊤ \left[\begin{array}{l} B_0^{\top} A_{2: 0} \\ B_1^{\top} A_{2: 1} \\ B_2^{\top} A_{2: 2} \end{array}\right]^{\top} B0A2:0B1A2:1B2A2:2 的项称为右因子或者B-block-因子。
  • A 5 : 2 A_{5: 2} A5:2称为中间因子或A-block-因子
  • [ C 6 ⊤ A 6 : 5 C 7 ⊤ A 7 : 5 C 8 ⊤ A 8 : 5 ] \left[\begin{array}{l} C_6^{\top} A_{6: 5} \\ C_7^{\top} A_{7: 5} \\ C_8^{\top} A_{8: 5} \end{array}\right] C6A6:5C7A7:5C8A8:5 称为左因子或C-block-因子。

右边因子。这一步计算与低秩因子分解的正确乘法𝐵-block-factors。注意,对于每个块,这是一个(N, Q) by (Q, P)矩阵乘法,其中N是状态维度,𝑃是头部维度。结果是每个块的(N, P)张量,它与扩展的隐藏状态h具有相同的维度。

这可以解释为:假设(块)的初始状态为0,每个块的最终状态是什么。换句话说,假设𝑥0:𝑗Q = 0,则计算ℎ𝑗Q+Q - 1。
在这里插入图片描述

右因子是计算输入到状态的

中心因子。这一步计算中心𝐴-block-factors项在低秩分解中的效果。在前一步中,每个块的最终状态具有总形状(T/Q, N, P)。现在乘以由生成的1-SS矩阵
A 2 Q − 1 : Q − 1 × , A 3 Q − 1 : 2 Q − 1 × , … , A T − 1 : T − Q − 1 × A_{2 Q-1: Q-1}^{\times}, A_{3 \mathrm{Q}-1: 2 \mathrm{Q}-1}^{\times}, \ldots, A_{\mathrm{T}-1: \mathrm{T}-\mathrm{Q}-1}^{\times} A2Q1:Q1×,A3Q1:2Q1×,,AT1:TQ1×
这一步可以通过任何计算1-SS乘法的算法来计算(也称为标量SSM扫描或cumprodsum运算符)。
这可以解释为:考虑之前的所有输入,每个块的实际最终状态是什么;换句话说,在考虑所有𝑥0:(𝑗+1)Q的情况下,这计算了真正的隐藏状态ℎ𝑗Q。

中心因子是计算输入到状态的

左边因子。这一步计算与低秩分解左边的乘法𝐶-block-factors。对于每个块,可以用一个矩阵乘法契约(QN, NP→QP)表示。

这可以解释为:考虑正确的初始状态ℎ𝑗Q−1,并假设输入𝑥𝑗Q:(𝑗+1)Q为0,每个块的输出是什么?换句话说,对于块𝑗,这只考虑先前的输入𝑥0:𝑗Q来计算正确的输出。

左边因子是计算状态到输出的
下面是一个例子
在这里插入图片描述

6.3 计算成本

我们定义了表示法BMM(B, M, N, K)来定义批量维为B的矩阵乘法收缩(MK, KN→MN)。由这个表示法我们可以推断出3个方面的效率:

  • 计算消耗:𝑂(BMNK)FLOPs。
  • 内存消耗:总共𝑂 (B(MK + KN + MN)) 空间。
  • 并行化:更大的M、N、K项可以利用现代加速器上专门的矩阵乘法单元。

中心块。二次SMA计算的消耗分为三步(式(16)):

  • 计算代价为BMM(T/Q, Q, Q, N)的核矩阵𝐶⊤𝐵。
  • 乘以掩码矩阵,这是对形状为(T/Q, Q, Q)的张量的elementwise操作。
  • 乘以代价为BMM(T/Q, Q, P, N)的𝑋值。
    在这里插入图片描述

Q较小(不够小就继续分小),直接用二次SMA
第一个地方有笔误,应该是 C B T CB^T CBT

低秩块:右边的因子。这一步是一个代价为BMM(T/Q, N, P, Q)的矩阵乘法。

低秩块:中心因子。这一步是在(N, P)独立信道上进行长度为T/Q的标量SSM扫描(或1-SS乘法)。该扫描的工作是TNP/Q,与其他因素相比可以忽略不计。
请注意,由于分块将序列的长度从T减少到T/Q,此扫描的成本比纯SSM扫描(例如Mamba的选择性扫描)小Q倍。因此,我们观察到,在大多数问题长度上,其他算法(附录B)可能更高效或更容易实现,而不会显著减慢速度。例如,通过1- ss矩阵乘法的简单实现的代价是BMM(1, T/Q, NP, T/Q),这比简单的递归/扫描实现更容易实现,也更高效。

标量SSM

低秩块:左因子。这一步是一个代价为BMM(T/Q, Q, P, N)的矩阵乘法。

总成本。如果我们设置N = P = Q(换句话说,状态维度、头维度和块长度相等),则上述所有BMM项都变为BMM(T/N, N, N, N)。其计算特性如下:

  • 𝑂(TN2)的总失败数。
  • 总内存𝑂(TN)。
  • 主要包括形状为(N, N)的矩阵乘法。

注意,内存消耗是紧的;输入和输出𝑥,𝑦具有形状(T, P) = (T, N)。同时,flop计数反映了N的一个额外因素,它是由自回归状态大小引起的,对所有模型都是相同的。

除了matmuls,在NP = N2特征和序列长度T/Q的情况下,还存在标量SSM扫描。这花费了𝑂(T/QN2)次FLOPs和𝑂(log(T/Q))深度。虽然它没有使用矩阵乘法,但它仍然是可并行的,与其他步骤相比,所做的工作可以忽略不计;这在我们的GPU实现中成本可以忽略不计。

与纯SSM和注意力模型的比较。通过仅利用矩阵乘法,二次注意力的硬件效率也非常高,但总FLOPs为T2𝑁。它在训练和推理时较慢的计算速度可以直接视为具有较大状态大小的结果——标准注意力的状态大小随序列长度T扩展,因为它缓存其历史而不压缩其状态。

线性ssm的总flop数为TNP = TN2,与SSD相同。然而,简单的实现需要一个状态扩展(15a)来物化额外的内存,以及一个不利用矩阵乘法的标量运算(15b)

我们注意到,还有许多其他的矩阵分解是可能的(例如,通过不同的结构化矩阵分解,请参阅附录B,以获取1-SS乘法的算法概要),这可能会为ssd带来更多的算法,这些算法在其他特殊设置下可能会更好。更广泛地说,我们注意到,半可分矩阵有丰富的文献,除了我们使用的SSS形式(定义3.2)外,还有更多的表示形式(定义3.2),甚至可能有更高效的算法。

7.Mamba-2架构

通过连接ssm和attention, SSD框架允许我们为两者开发共享的词汇和技术库。在本节中,我们将讨论一些原本为Transformer开发的的思想来理解和修改SSD层的例子。我们讨论了几种设计选择,从而形成了Mamba-2架构。这些变差轴将在9.4节中消融。


图6:(Mamba-2架构。)Mamba-2块通过去除序列线性投影简化了Mamba块;SSM参数𝐴、𝐵、𝐶在块的开头产生,而不是作为SSM输入𝑋的函数。与NormFormer (Shleifer、Weston和Ott 2021)一样,增加了一个额外的规范化层,以提高稳定性。𝐵和𝐶投影只在𝑋头上共享一个头,类似于多值注意力(MVA)。

7.1块设计

我们首先讨论了对神经网络块的修改,这些修改独立于内部序列混合层(即核心SSD层之外)。
平行参数投影。Mamba-1的动机是基于SSM为中心的观点,其中选择性SSM层被视为从𝑋↦→𝑌的映射。SSM参数𝐴,𝐵,𝐶被视为附属参数,是SSM的输入𝑋的函数。因此,线性投影定义(𝐴,𝐵,𝐶)发生在初始线性投影创建𝑋之后。

在Mamba-2中,SSD层被视为从𝐴,𝑋,𝐵,𝐶↦→𝑌的地图。因此,在区块的开始,用单个投影并行地产生𝐴,𝑋,𝐵,𝐶是有意义的。请注意与标准注意力架构的类比,其中𝑋、𝐵、𝐶类似于并行创建的𝑄、𝐾、𝑉投影。

请注意,对SSM的𝐴,𝐵,𝐶,𝑋输入采用并行投影,可以略微减少参数,更重要的是,通过使用标准的Megatron分片模式,更适合于更大模型的张量并行(Shoeybi等人2019)。

额外的Normalization。在初步实验中,我们发现在较大的模型中容易出现不稳定性。我们可以通过在最终输出投影之前的块中添加额外的规范化层(例如LayerNorm、GroupNorm或RMSNorm)来缓解这一问题。规范化的使用与NormFormer架构(Shleifer、Weston和Ott 2021)最直接相关,该架构还在MLP和MHA块的末尾添加了规范化层。

我们还注意到,这种变化类似于最近与Mamba-2相关的其他模型,这些模型来自于线性注意力的观点。原始线性注意力公式通过分母项进行归一化,该项模拟了标准注意力中的softmax函数的归一化。TransNormerLLM (Qin, Dong Li, et al. 2023)和RetNet (Y. Sun et al. 2023)发现这种归一化是不稳定的,并在线性注意力层之后添加了额外的LayerNorm或GroupNorm。我们额外的归一化层与这些略有不同,它发生在乘性门分支之后而不是之前。

7.2 序列转换的多头模式

回想ssm被定义为序列变换(定义2.1),其中:

  • 𝐴,𝐵,𝐶参数具有状态维度N。
  • 它们定义了序列变换RT→RT,例如可以表示为一个矩阵𝑀∈R(T,T)。-
  • 这种变换在输入序列𝑋∈R(T,P)上运行,在P轴上独立。

我们可以把这看作是定义序列变换的一个头。
定义7.1(多头模式)。一个多头序列变换由H个独立的头组成,总模型维数D = d_model。参数可能被绑在头部上,导致头部模式

状态大小N和头部维度P分别类似于注意力的𝑄𝐾头部维度和𝑉头部维度。就像在现代Transformer架构中一样(Chowdhery et al. 2023;Touvron, Lavril等人2023),在Mamba-2中,我们通常选择这些为64或128左右的常数;当模型维数D增加时,我们在保持头部维数N和P固定的情况下增加头部的数量。为了描述如何做到这一点,我们可以转移和概括多头注意力的想法,为ssm或任何一般序列转换定义类似的模式。


多头SSM (MHS) /多头注意力(MHA)模式。经典的MHA模式假设头维度P整除模型维度D,头的数量定义为H = D/P。然后,通过创建每个参数的H个独立副本来创建核心序列变换的H个副本。请注意,虽然MHA模式最初是为注意力序列转换描述的,但它可以应用于与定义2.1兼容的任何东西。例如,多头SSD层将接受如式(17)所示形状的输入,其中SSD算法在H = n_heads维度上广播。

多收缩SSM (MCS) /多查询注意力(MQA)模式。多查询注意力(Shazeer 2019)是对注意力的一种聪明的优化,可以显著提高自回归推理的速度,这依赖于缓存𝐾和𝑉张量。这种技术简单地避免了给𝐾和𝑉额外的头部维度,或者换句话说,在所有𝑄头部中广播一个(𝐾,𝑉)头部。

利用状态空间对偶性,我们可以定义一个等价的MQA SSM版本,如式(18)。在这里,𝑋和𝐵(类似于attention的𝑉和𝐾)在H的头部共享。我们也称其为多收缩SSM (MCS)头部模式,因为控制SSM状态收缩的𝐶参数在每个头部都有独立的副本。

开始打算把Transformer中的成果拿到SSM上来了。Transformer上有K,V在头间共享的,那么SSM也可以有。K,V对应B,X,所以B,X在头间共享。

多输入SSM (MIS) /多值注意力(MVA)模式。虽然MQA因为其KV缓存而值得关注,但它并不是ssm的自然选择。相反,在Mamba中,𝑋被视为SSM的主要输入,因此𝐵和𝐶是跨输入通道共享的参数。在公式(20)中定义了一种新的多输入SSM (MIS)模式的多值注意力(MVA),它可以再次应用于任何序列变换,如SSD。

X如果是头共享的,头间输入就是个一维的。表单能力太弱了。所以SSM的自然选择当然不是X是头共享的。这里MIS模型下,C,B是头共享的。

有了这个词汇表,我们可以更精确地描述原始曼巴架构。

命题7.2。Mamba架构(Gu和Dao 2023)的选择性SSM (S6)层可以被视为具有

  • 头部尺寸𝑃= 1:每个通道都有独立的SSM动态𝐴。
  • 多输入SSM (MIS)或多值注意力(MVA)头部结构:𝐵,𝐶矩阵(对应于𝐾,𝑄在注意力二元性中)在输入的所有通道𝑋(对应于𝑉在注意力中)中共享。

S6模型对应的P=1,所以有H=D个头,且头间B和C是共享的。

当应用于SSD(9.4.3节)时,我们也可以去掉这些头部模式的变体。有趣的是,尽管在参数计数和总状态维度上进行了控制,但下游性能仍有明显的差异。通过经验发现,最初在Mamba中使用的MVA模式表现最好。

分组的头部模式。多查询注意力的思想可以扩展到分组查询注意力(Ainslie et al. 2023):可以创建G个独立的K和V头,而不是1个K和V头,其中1<G和G整除H。这是由于弥合了多查询和多头注意力之间的性能差异,并通过将G设置为分片数量的倍数来实现更有效的张量并行(第8节)。
同样,在Mamba-2中使用的多输入SSM头部模式可以很容易地扩展到分组输入SSM (GIS),或同义词分组值注意力(GVA)。泛化很简单,为简单起见,我们省略了细节。

类似地,在Mamba-2中使用的多输入SSM头模式可以很容易地扩展到分组输入SSM (GIS),或同义分组值注意力(GVA)。泛化很简单,为简单起见,我们省略了细节。

对H个头进行分组。比如3个3个一组。这样G就等于3。

7.3 线性注意力的其他SSD扩展

本文描述了一个由线性注意力驱动的SSD架构修改的例子。我们将在9.4.3节中删除这些设置,作为负面结果的一种形式,我们发现它们并没有显著提高性能,因此不能将它们作为默认设置。尽管如此,这些说明了如何结合大量关于注意力的文献来定义SSD的变体。我们将内核特征图的选择视为Mamba-2架构中的一个超参数,并期望其他受注意力启发的简单修改也成为可能。

Softmax注意力的核注意力近似。线性注意力或核注意力的许多变体是由将softmax(𝑄𝐾⊤)的注意力分数视为

  • 一个指数核𝑍= exp(𝑄𝐾⊤),对于某些内核特征图,可以近似为𝑍=𝜓(𝑄)𝜓(𝐾)⊤。
  • 通过𝑀=𝐺/𝐺11⊤对内核进行规范化,使行和为1,其中elementwise进行除法,1是所有1的向量。

指数核特征映射。在Mamba-2中,我们合并了一个灵活的内核特征映射,并将其应用于𝐵和𝐶分支(对应于𝐾和𝑉分支)。为了简单和对称,特征映射也可以选择性地应用于𝑋(𝑉)分支。这在图6中由任意非线性表示。默认情况下,我们简单地选择𝜓作为是一个逐元素的Swish / SiLU函数(Hendrycks和Gimpel 2016;Ramachandran, Zoph和Le 2017)。我们在9.4.3节的消融中探索了其他选项,包括线性注意力、Performer、随机特征注意力和coformer使用的特征图(第4.1.3节)。

**合并归一化(分母)项。**要找到分母项,我们只需计算𝑀1。但回想一下,模型的最终输出只是𝑌=𝑀𝑋(式(16))。因此,可以简单地通过使用额外的列1对𝑋进行增广来找到规范化项,从而得到形状为(T, P + 1)的张量。注意,在这种情况下,内核特征映射𝜓必须是正的,以便和是正的。

8.ssm系统优化

本文描述了对ssm的几种系统优化,特别是Mamba-2架构,以进行大规模高效的训练和推理。重点关注张量并行和序列并行,用于大规模训练,作为一个很好的变长序列,用于高效的微调和推理。

8.1 张量并行

张量并行(TP) (Shoeybi et al. 2019)是一种模型并行技术,它将每个层(如attention、MLP)划分为在多个加速器(如gpu)上运行。这种技术被广泛用于训练大多数大型模型(Brown et al. 2020;Chowdhery等人2023;Touvron, Lavril等2023;Touvron, L. Martin等人。2023)在GPU集群上,每个节点通常有4-8个GPU,具有NVLink等快速网络。TP最初是为Transformer架构开发的,将其应用于其他架构并不简单。我们首先展示了在Mamba架构下使用TP所面临的挑战,并展示了Mamba-2架构是如何设计以提高TP效率的。

回顾mamba架构,使用单个输入𝑢∈R𝐿×𝑑(没有批处理为简单起见),输入投影矩阵𝑊(𝑥)𝑊(𝑧)∈R𝑑×𝑒𝑑𝑒是扩张的因素(通常2),和输出投影矩阵𝑊(𝑜)∈R𝑒𝑑×𝑑:

使用TP,假设我们想要将计算划分到2个gpu上。很容易将输入投影矩阵𝑊(𝑥)和𝑊(𝑧)分成两个大小分别为𝑑×𝑒𝑑2的分区。那么每个GPU将存储𝑥𝑐的一半大小𝐿×𝑒𝑑2。然而,我们看到,由于Δ,𝐵,𝐶是𝑥𝑐函数,因此我们需要在gpu之间额外的all-reduce,以在计算Δ,𝐵,𝐶之前获得整个𝑥𝑐。之后,两个gpu可以并行计算SSM,因为它们沿着𝑑独立。最后,我们可以将输出投影矩阵𝑊(𝑜)分成两个大小为𝑒𝑑2 ×𝑑的分区,并在最后进行all-reduce。与transformer相比,我们会发生两次all-reduce,而不是一次,通信时间增加了一倍。对于大规模transformer训练,通信可能已经花费了相当多的时间(例如10-20%),而加倍的通信将使Mamba在大规模训练中不那么高效。

原来的mamba架构,中间因为算 Δ \Delta Δ等等,必须要占用一次all-reduce,所以通信增加一倍。

对于Mamba-2,我们的目标是每个块只有一个all-reduce,类似于transformer中的注意力或MLP块。因此,我们用一个投影,直接从𝑢而不是从𝑥𝑐得到Δ,𝐵,𝐶,允许我们分割这些投影矩阵。这意味着我们在不同的GPU上有不同的Δ,𝐵,𝐶集合,这相当于在一个更大的“逻辑GPU”上有几个Δ,𝐵,𝐶的“组”。此外,我们在每个块中使用GroupNorm,其组的数量可以被TP度整除,这样TP组中的gpu在块内就不会通信:

我们看到,我们只需要划分输入投影矩阵和输出投影矩阵,并且只需要在块的末尾进行全部归约。这类似于TP for attention和MLP层的设计。特别是,如果我们有TP度2,我们将分裂𝑊(𝑥)=(𝑊(𝑥)1,𝑊(𝑥)2]与𝑊(𝑥)𝑖∈R𝑑×𝑒𝑑/ 2,𝑊(𝑧)=(𝑊(𝑧)1,𝑊(𝑧)2]与𝑊(𝑧)𝑖∈R𝑑×𝑒𝑑/ 2,和𝑊(𝑜)= "𝑊(𝑜)1𝑊(𝑜)2 #𝑊(𝑜)𝑖∈R𝑒𝑑/ 2×𝑑。对于𝑖= 1,2,TP Mamba-2层可以写成:

我们在图7(左)中说明了张量与Mamba-2的并行。

图7:(Mamba-2块的并行性)(左:张量并行性)我们分割输入投影矩阵𝑊(𝑥),𝑊(𝑧)和输出投影矩阵𝑊(𝑜)。每个SSM头(𝐴,𝐵,𝐶,𝑋)↦→𝑌生活在单个设备上。选择GroupNorm作为最后的规范化层可以避免额外的通信。我们需要每层一个all-reduce,就像MLP或Transformer中的注意力块一样。(右:序列/上下文并行)类似于SSD算法,有多个设备,我们可以沿着序列维度进行分割。每个设备计算其序列的状态,然后将该状态传递给下一个GPU。

8.2 序列并行

对于非常长的序列,我们可能需要沿着序列长度维度将输入和激活分割到不同的gpu。主要有两种技术:
1.残差和归一化操作的序列并行(SP):由Korthikanti等人(2023)首次提出,该技术将TP中的all-reduce分解为reduce-scatter和all-gather。注意到对同一TP组中的所有gpu的相同输入重复进行残差和归一化操作,SP通过执行以下操作沿着序列长度维度拆分激活:reduce-scatter、残差和归一化,然后all-gather。
由于Mamba-2架构使用相同的残差和规范化结构,因此SP无需修改即可应用。
2.用于token混合操作(attention或SSM)的序列并行性,也称为“上下文并行性”(CP)。已经为注意力层开发了几种技术(例如,环注意力(Liu, Yan, et al. 2024;Liu、Zaharia和Abbeel 2023)),使用复杂的负载平衡技术(Brandon等人2023)。序列并行attention的困难在于,我们可以将查询和键拆分为块,但每个查询块都需要与关键块交互,导致通信带宽是worker数量的二次型。

使用SSM,我们可以以一种简单的方式分割序列:每个worker获取一个初始状态,根据他们的输入计算SSM,返回最终状态,并将该最终状态传递给下一个worker。通信带宽与工作节点的数量成线性关系。这种分解与SSD算法中的块分解完全相同(图5),将其划分为块/块。我们在图7(右)中说明了这种上下文并行。

8.3 变量长度

虽然预训练通常为批处理使用相同的序列长度,但在微调或推理期间,模型可能需要处理不同长度的不同输入序列。处理这种情况的一种简单方法是将批处理中的所有序列右补到最大长度,但如果序列的长度相差很大,这可能是低效的。对于transformer,已经开发了复杂的技术来避免填充并在gpu之间进行负载平衡(Zeng et al. 2022;Y. Zhai et al. 2023),或在同一批次中打包多个序列并调整注意力掩码(Ding et al. 2024;Pouransari et al. 2024)。特别是对于SSMs和Mamba,我们可以通过简单地将整个批处理为一个长序列来处理可变的序列长度,并避免在单个序列之间传递状态。这相当于简单地在一个序列的末尾为标记𝑡设置𝐴𝑡= 0,以防止它将信息传递给标记𝑡+ 1,该标记属于不同的序列。

9.实证验证

在对循环模型具有挑战性的合成召回任务(9.1节)以及标准语言建模预训练和下游评估(9.2节)上对Mamba-2进行了实证评估。验证了SSD算法比Mamba-1(9.3节)高效得多,并与优化后的中等序列长度的注意力相当。最后,我们讨论了Mamba-2架构中的各种设计选择(第9.4节)。

9.1合成:联想回忆

合成联想回忆任务是测试语言模型在上下文中查找信息的能力的常用任务。总的来说,它们涉及向自回归模型提供键值关联对,然后提示模型在显示之前见过的键时产生正确的补全。**多查询关联回忆(MQAR)**任务是该任务的一种特殊表述,该任务要求模型记忆多个关联(Arora, Eyuboglu, Timalsina等,2024)。原始的Mamba论文报告了相关合成任务的结果,特别是选择性复制(Gu和Dao 2023)和归纳头(Olsson et al. 2022),它们可以被视为更容易的联想回忆任务。MQAR任务也与"电话簿查找"任务密切相关,由于ssm等循环模型的有限状态容量,这些任务被证明是具有挑战性的(De等人2024;Jelassi et al. 2024)。

11.结论

本文提出了一个基于经过充分研究的结构化矩阵类的理论框架,弥合了ssm和注意力变体之间的概念差距。该框架产生了关于最近的ssm(如Mamba)在语言建模方面的表现以及transformer的见解。此外,所提出的理论工具通过连接算法和系统两端的进步,为改进ssm(和潜在的transformer)提供了新思路。作为一个示范,该框架指导我们在SSMs和结构化注意力的交叉点上设计一个新架构(Mamba-2)。

C.2 自回归掩码注意力是一种半可分离结构注意力

我们来证明5.2节中的定理5.2。在4.3节中,我们将结构化注意力定义为掩码注意力的广泛泛化,其中效率的属性(即内核注意力的线性时间形式)被抽象为结构化矩阵乘法的效率。然而,除了计算效率之外,标准线性注意力(Katharopoulos et al. 2020)还具有两个重要属性。首先,它是因果关系,这对于自回归建模等设置是必需的。此外,它具有高效的自回归生成。换句话说,一个自回归步骤的成本——即在看到𝑥𝑇后计算输出𝑦𝑇的增量成本,因为𝑥0:𝑇已经被看到并预处理了——只需要常数时间。
本文描述了哪些SMA实例具有有效的自回归。
在SMA框架中,因果关系等价于掩码𝐿是下三角矩阵的约束。
描述具有高效自回归的𝐿矩阵的空间更加困难。我们将采用狭义的自回归过程技术定义,遵循时间序列文献中的经典定义(例如ARIMA过程(Box et al. 2015))。
C.2定义。我们定义一个k阶自回归变换𝑥∈R𝑇↦→𝑦∈R𝑇,每个输出𝑦𝑡只依赖过去𝑘个输出和当前的输入:
y t = μ t x t + ℓ t 1 y t − 1 + ⋯ + ℓ t k y t − k y_t=\mu_t x_t+\ell_{t 1} y_{t-1}+\cdots+\ell_{t k} y_{t-k} yt=μtxt+t1yt1++tkytk
请注意,𝐿是累加矩阵的情况是一个特殊情况,𝑘= 1,因此𝑦𝑡=𝑥𝑡+ y t − 1 y_{t-1} yt1。在此定义下,利用半可分矩阵的性质对高效自回归线性变换空间进行了刻画。定理C.3形式化并证明了定理5.2。

这里有笔误,应该是 y t − 1 y_{t-1} yt1

定理C.3。让𝐿∈R𝑇×𝑇是一个有效的k阶自回归转换。那么𝐿是一个阶为𝑘+ 1的状态空间模型。
证明:。。。
接下来,根据半可分性的秩特征(定义3.1),𝑘+1带矩阵是𝑘+1-半可分的。根据命题C.1,𝐿的逆因此最多是𝑘+2-半可分的。由于带状矩阵的附加结构,可以得到稍强的𝑘+1的界。最后,根据定理3.5,𝐿被刻画为一个阶𝑘+1状态空间模型。
换句话说,高效的自回归注意力是半可分的SMA。

1.这里的证明要用到C.1的结论。
2.放缩,最后得到k+1半可分,等价于k+1阶状态空间模型。

  • 8
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值