前言
实话说,过去一两月一直忙着我司两大类项目的推进
- 一类是正在逐一上线基于大模型的论文翻译、论文审稿、论文对话、论文修订/润色、论文idea提炼等等(截止到24年8月底,其中的审稿和翻译已上线七月官网 )
- 一类是正在抓紧做面向一个个工厂的具身智能机器人的解决方案,且很快会分别在我司在各地的办公室(南京、长沙、武汉、北京),一一摆上一两台干活的具身机器人
所以虽然说mamba2已发布一月有余,但实在是没有一块完整的时间来对其做详尽而细致的解读,而最终促使我来写的最大的动力还是来源于我半年前对mamba1的解读,实在是太受欢迎了且影响力巨大(截止到24年7月初,半年下来阅读量10万,2千余次收藏,在同样发表半年内文章中的表现很突出)
加之之前就有读者在我对上面mamba1做解读的文章下留言,什么时候出mamba2的解读,让我好几次跃跃欲试想开写
然,在我下定决心写本文之前,内心还是有过一阵小纠结的
- 一方面,怕没有一大块完整的时间(回想过去,23年上半年因为ChatGPT,公司重新焕发生机,个人也前所未有的沉迷于技术,又因23年下半年做大模型项目延续至今,今后因为业务的增长 大量的各种会议 可能难以再像过去一年半百分百沉迷于技术了)
- 二方面,mamba2的论文特别长,即《Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality》一文长达52页(这个则是两位作者写的解读blog:State Space Duality (Mamba-2) Part I - The Model),全是各种概念、公式,故为了更好的理解mamba2,建议先熟练mamba1
当然,mamba2的核心主要解决两个问题:1 打通SSM与transformer之间的联系,2 将mamba2表述为矩阵乘法以加速训练
具体而言,在结构化掩码注意力SMA中
1) 首先,可以通过掩码矩阵(比如因果掩码)来指导注意力机制——控制信息流向,从而决定哪些信息是重要的,哪些可以忽略
2) 其次,针对传统注意力机制下的计算对Q, K, V的操作——L* (QK^T)V,都可以找到一个近似Q K V的结构化的N-半可分矩阵,然后,通过与对应的N-半可分矩阵相乘,以达到加速计算的目的
总之,这种矩阵能够表示不同种类的注意力形式,相当于都能通过矩阵运算来进行不同掩码下的注意力操作
3) 而N-半可分矩阵「准确的说是1-半可分矩阵(简称1-SS矩阵)」,可以直接应用于SSM中的A矩阵(表示状态转移的矩阵)
从而通过将SMA中的结构化矩阵应用于SSM中的A矩阵,如此,便将SSM和注意力结合起来了
不过还是因为过去十多年写博客的经验,使得自己在面对再难啃的算法都有足够的自信与底气,坚信都可以一步步拆解、一步步抽丝剥茧并清晰易懂的写出来
- 读者在看本文时,也不用急,一步步来,可以慢慢看懂的,且未来一两月 我也会不断修订本文以让之不断更加通俗易懂
- 且为了解释清楚每一个定义、公式、矩阵,我会在文中不厌其烦的、不断列举大量、具体,但论文中没有的矩阵示例,以不断降低理解门槛
故本文最终还是来了
第一部分 背景回顾:从SSM、结构化矩阵到SSD的一系列定义
1.1 结构化SSM的定义:Structured State Space Model
1.1.1 离散化、循环结构表示、卷积结构表示
虽然在之前对mamba1的讲解中已经讲过了很多背景,但为本文的完整性起见,还是把一系列背景知识按照mamba2论文的思路,再度逐一梳理下
首先,结构化状态空间序列模型S4是受到的特定连续系统的启发(如下述公式1所示,是结构化SSM的一般离散形式),该系统将一维序列通过隐式潜在状态 做映射(相当于将SSM简单地写成矩阵乘法)
- 其中、均是标量,则被视为具有N维的向量,且
- 其中的 矩阵 控制时间动态,从而必须是结构化的(结构化SSM也因此得名),以便能够足够高效地计算这种序列到序列的转换,从而在深度神经网络中使用
梳理一下结构化SSM的发展历史
- 最初的结构化SSM起源于函数的连续时间映射,而不是直接对序列进行操作
在连续时间视角中,在公式(1a)中,矩阵 (𝐴, 𝐵)不是直接学习的,而是从底层参数生成的,并且伴随着一个参数化的步长 Δ
“连续参数”通过固定公式和转换为“离散参数”(𝐴, 𝐵),其中这对 (, )被称为discretization rule - 结构化 SSM 可以被视为一种递归神经网络RNN,其中线性赋予它们额外的属性,并使它们能够避免传统 RNN 的顺序计算。相反,尽管有这种简化,SSM 仍然可以完全表达为序列变换 更多详见此文《一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba》的第2.1.2节
- 当SSM的动态在时间上是恒定的,如公式(1)所示,该模型称为线性时不变(linear time-invariant,简称LTI)模型,在这种情况下,它们等同于卷积 因此,SSM也可以被视为CNN的一种类型,但卷积核通过SSM参数 (𝐴, 𝐵, 𝐶)隐式参数化,且卷积核通常是全局的而不是局部的
反过来,通过经典的信号处理理论,所有充分良好的卷积都可以表示为SSM
通常,以前的LTI SSM会
- 使用卷积模式进行高效的可并行训练(整个输入序列提前看到)
- 并切换到递归模式(如本节开头的公式1所述)进行高效的自回归推理(输入逐步看到)
1.1.2 mamba一代的问题:没法用矩阵乘法
当在 Mamba1 中被引入为选择性 SSM时,则相当于允许(A, B, C)这三个参数随时间而变化(如下面公式2所示),此时,、、
公式2与标准的 LTI 公式1相比,该模型可以在每个时间步选择性地关注或忽略输入
在信息密集型数据如语言上,它的表现被证明远优于 LTI SSM,特别是随着其状态大小 N的增加,允许更多的信息容量
然而,它只能在递归模式下计算,而不是卷积模式,并且需要“专门的硬件感知实现”才能高效,即如下图所示
即便如此,它仍然不如硬件友好的模型(如 CNN 和 Transformer)高效,因为它没有利用矩阵乘法单元,而现代加速器(如 GPU 和 TPU)正是为此而专门设计的
总之,虽然时间不变SSM 与连续、递归和卷积序列模型密切相关,但它们与注意力机制没有直接关系。所以mamba2想揭示选择性SSM和注意力机制之间的更深层次关系,并利用这一点显著提高SSM的训练速度,同时允许更大的状态规模N
1.1.3 结构化SSM作为序列变换:三个定义之2.1 2.2 2.3
请直接看一下三个定义(分别定义序列变换、S6和注意力机制的序列变换形式、序列变换与矩阵的联系)
- 定义 2.1 一般而言,所谓序列变换指的是序列上的参数化映射
其中,,并且𝜃是任意参数集合
表示序列或时间轴,可以作为下标索引到第一个维度,例如
序列变换(例如SSM或自注意力机制)是深度序列模型的基石,它们被整合到神经网络架构中 例如Transformer
其实上面的公式1或2中的SSM便是一个序列变换,且 P = 1 当然,它可以通过简单地在此维度上来推广到 P > 1(换句话说,将输入视为 P 个独立序列并对每个序列应用SSM,即可以将 P视为一个头维度) - 定义 2.2 定义SSM 操作符作为序列变换,由上面的公式2定义
在 SSM 中, N维度是一个称为状态大小或状态维度的自由参数,也称之为状态扩展因子,因为它将输入/输出的大小扩展了 𝑁倍,这对这些模型的计算效率有影响
(其实许多类型的序列变换,例如注意力机制,都可以表示为跨序列维度的单一矩阵乘法) - 定义 2.3 如果一个序列变换可以写成形式,其中是一个依赖于参数𝜃的矩阵,称其为矩阵变换,且用矩阵𝑀来表示序列变换
当然,在上下文明确时,通常省略对的依赖
1.2 一系列定义:注意力机制、结构化矩阵、SSD
1.2.1 线性注意力机制的定义
注意力机制已经非常经典了(如果还不熟悉注意力机制的,请参见此文:Transformer通俗笔记:从Word2Vec、Seq2Seq逐步理解到GPT、BERT),屡见不鲜,其为序列中每对位置分配分数,使每个元素能够“关注”其余部分
迄今为止,最常见和最重要的注意力机制变体是softmax自注意力机制,其定义如下
对于,由于注意力机制需要一次次计算两两token之间的注意力(毕竟有这个计算),导致了二次方的计算复杂度
为了降低二次方的复杂度,已经提出了许多注意力的变体,其中最重要的变体是线性注意力(详见此文的2.2.1 什么是线性transformer:Transformers are RNNs与cosformer)
- 粗略地说,这类方法通过将softmax折叠到核特征映射中,并利用矩阵乘法的结合性将注意力计算中的矩阵左乘改成右乘,即
- 如下图右侧所示,将QKV的左乘变成右乘后,从⽽将理论计算复杂度降为线性「更多详见此文《七月论文审稿GPT第1版:通过3万多篇paper和10多万的review数据微调RWKV》的2.2节」
值得一提的是
- 提出线性注意力的这个标题:
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention「作者:A Katharopoulos · 2020」 - 是否与提出mamba2的论文标题:
Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
有着很高的相似性呢
再进一步,既然transformer是RNN,而SSM某种意义上也是RNN,那mamba2和transformer是否有着直接的联系?不急,请继续看下文的讲解
- 此外,在因果(自回归)注意力的重要情况下,他们表明,当因果掩码被合并到左侧作为,其中是下三角1矩阵时,右侧可以扩展为递归(Moreover, in the important case of causal (autoregressive) attention, they show that when the causal mask is incorporated into the left-hand side as (𝐿 ◦ 𝑄𝐾⊤) · 𝑉 , where 𝐿 is the lower-triangular 1’s matrix, then the right-hand side can be expanded as a recurrence)
这个的作用在于确保在计算注意力权重时,每个位置只能看到它之前的位置(类似GPT做预训练预测下一个token时,必会遮住当前token的后续token,不然就无所谓预测了) - 最近的一些工作,如RetNet(Y. Sun等,2023)和GateLoop(Katsch 2023)将其加强为更一般形式的
为了方便大家更好的理解上面这段话,我再给大家举个具体的矩阵例子,以形象说明
- 定义查询Q、键K、值V矩阵,为了简化,可以使用随机的矩阵值来表示它们
- 接下来,计算
- 现在,我们定义一个下三角1矩阵 ,用于实现因果掩码
- 接下来,计算
- 最后,再计算
1.2.2 结构化矩阵(Structured Matrices)的定义:方便做矩阵乘法
一般矩阵需要 个参数来表示,并且执行诸如矩阵-向量乘法等基本操作需要时间。而所谓的结构化矩阵是指那些
- 可以压缩表示,比如在亚二次(理想情况下是线性)参数中表示
- 并且通过快速算法(最重要的是矩阵乘法),直接操作这种压缩表示
也许最典型的结构化矩阵家族是稀疏矩阵和低秩矩阵。 然而,还存在许多其他家族,例如Toeplitz矩阵、Cauchy矩阵、Vandermonde矩阵和蝶形矩阵
1.2.3 SSD(结构化状态空间对偶)的定义:注意力矩阵乘以掩码矩阵
状态空间对偶(SSD)层可以定义为选择性SSM(如之前公式2所示)的特例
可以应用SSM作为递归(或并行扫描)的标准计算,其在序列长度上具有线性复杂度。 与Mamba中使用的版本相比,SSD有两个小的不同点:
- 的结构从对角线进一步简化为标量乘以单位矩阵结构。 在这种情况下,每个也可以仅用一个标量来表示
- 使用了更大的头维度 ,相比于Mamba1中使用的 P = 1,通常选择,而Transformer一般也会这样设置头的维度
与原始选择性SSM相比,这些变化可以被视为在略微降低表达能力的同时 显著提高训练效率。 特别是,新算法将允许在现代加速器上使用矩阵乘法单元
如下图所示
- 原论文Sec.3中的Semiseparable Matrices——半可分矩阵,将揭示结构化矩阵与SSM之间的联系
- 原论文Sec.4中的Structured Masked Attention(SMA),将揭示结构化矩阵与注意力之间的联系
- 原论文Sec.5中的State SpaceDuality(SSD),将揭示SSM与注意力之间的联系,如此,基于SSD,便发展出来了mamba2
更进一步,SSD的对偶形式是一种与注意力密切相关的平方计算,其定义为
其中是依赖于输入的标量,范围在 [0, 1]之间
SSD与标准的softmax注意力相比,有两个主要区别
- 去掉了softmax
- 注意力矩阵按元素乘以一个额外的掩码矩阵
这两种变化都可以被视为解决了原始注意力中的问题。 例如,有研究发现softmax在注意力分数中会引起问题,如“注意力陷阱”现象(Darcet等,2024;Xiao等,2024)
更重要的是,掩码矩阵可以被视为用不同的数据依赖位置掩码替换Transformer的启发式位置嵌入,从而控制跨时间传递的信息量(the mask matrix 𝐿 can be viewed as replacing the heuristic positional embeddings of Transformers with a different data-dependent positional mask that controls how much information is transfered across time)
更广泛地说,这种形式是下文定义的线性注意力的SMA泛化的一个实例
- 总之,通过展示SSM具有矩阵变换形式,对于一个依赖于的矩阵,各种形式的SSD可以通过统一的矩阵表示连接起来
- 特别地,SSD的对偶形式等价于通过矩阵 𝑀进行的朴素(平方时间)乘法,而递归形式是一种利用 𝑀结构的特定高效(线性时间)算法
以上之外,任何用于乘以的算法都可以应用,此次提出的硬件高效SSD算法是一种新的结构化矩阵乘法方法,一方面,其涉及 𝑀的块分解,比纯线性或二次形式获得更好的效率权衡;二方面,与一般选择性SSM——mamba1(Gu和Dao 2023,即Albert Gu and Tri Dao. “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”)相比,它相对简单且易于实现
第二部分 从SSM是Structured Matrices、使用结构化矩阵推广线性注意力到SSD
2.1 SSM是结构化矩阵:State Space Models are Structured Matrices(含公式3 4 5 6 7 8)
2.1.1 SSM的矩阵变换形式:状态乘以矩阵来生成,再用表示中的系数(公式3)
回顾一下,对选择性SSM——即mamba1的定义是通过之前的公式2定义的参数化映射
SSM中,有
根据定义,,通过归纳法,可知时刻 的状态 ,可以表示为之前各个时刻的状态的加权和,即如下
上述公式中的
- 第一行的每一项表示的是之前某个时刻 的状态 经过一系列线性变换后的结果,最后这些结果加在一起得到了当前时刻的状态
- 第二行中的表示从一直乘到
为方便大伙一目了然,加之十多年前,我就提醒自己,写博客的目标之一是 如果某个算法看别的资料看不懂、看不动,那可以看懂、看动我的(坚持10多年来了,好处是博客影响力巨大,不好是累人),故还是要不厌其烦的解释下
其实理解上面那个公式很简单,直接一步一步推导一下即可,如下所示便可一目了然
- t = 0时
在这种情况下,是单位矩阵,故有- t = 1时
- t = 2时
- t = 3时
当然,这个过程可以借鉴下mamba1的这个图,只是、还没加上这个参数而已,所以在不同的输入x之下,便不会存在不同的、(下图来自mamba1解读一文的第2.1.2节SSM的循环结构表示:方便快速推理)
通过乘以矩阵来生成并将方程在上向量化,可推导出SSM的矩阵变换形式,如下(称之为公式3)
对于上述公式3,我举个例子,比如因为有
故可得
好比
2.1.2 半可分矩阵(定义3.1、定义3.2/公式4、定义3.3/公式5、定义3.4):顺序半可分、1-半可分矩阵
首先,先来看下半可分矩阵(Semiseparable Matrices)的定义「称之为定义3.1,在有的文献中也被称为 (N, 0)-半可分性)」
一个(下三角)矩阵 𝑀是 N-半可分的,如果包含在下三角部分(即对角线或以下)的每个子矩阵的秩——Rank最多为 N,则称 N为半可分矩阵的阶数或秩
Definition 3.1. A (lower triangular) matrix 𝑀 is N-semiseparable if every submatrix contained in the lower triangular portion
(i.e. on or below the diagonal) has rank at most N. We call N the order or rank of the semiseparable matrix
其和其他形式的相关“可分”结构(例如准可分矩阵和其他半可分矩阵的定义)有时被称为结构化秩矩阵(或秩结构矩阵),因为它们的子矩阵由秩条件表征
半可分矩阵有许多结构化表示,包括分层半可分HSS、顺序半可分SSS和Bruhat形式(Pernet和Storjohann 2018),此处将主要使用SSS形式
2.1.2.1 顺序半可分SSS表示「The Sequentially Semiseparable (SSS) Representat」:每个 N-半可分矩阵都有一个 N-SSS 表示
先看顺序半可分矩阵SSS表示的定义(称其为定义3.2,公式4)
一个下三角矩阵 𝑀 ∈ R(T,T)如果它可以写成以下形式,则具有 N-顺序半可分(SSS)表示
对于向量和矩阵,定义算子 SSS使得
换言之,如果是且(其实和上面的但一个意思),使得
则相当于
且和之前的公式3,一个意思
这个SSS表示带来的好处是如定义3.3所示
一个 N-SSS 矩阵 𝑀具有上面公式(4)的表示,则便是 N-半可分的「Lemma 3.3 An N-SSS matrix 𝑀 with representation (4) is N-semiseparable」
证明如下(定义为公式5)
考虑任何非对角块,其中 𝑗 ′ > 𝑗 ≥ 𝑖 > 𝑖′「如原论文中所说,Consider any off-diagonal block 𝑀𝑗:𝑗 ′,𝑖′:𝑖 where 𝑗′ > 𝑗 ≥ 𝑖 > 𝑖′」,这具有显式的秩-N分解为
为了避免正在阅读此文的你头疼,我还是用一个具体的示例来形象的说明下上述公式5
- 假设有以下矩阵
- 根据公式5的结构,选择j' = 2、j = 1、i = 1、i’ = 0,然后有:
如此,上述的相当于有4个式子的结果需要逐一计算,具体计算过程如下步骤3 4 5 6所示- 先算左上角那个式子的结果,可得
先算前两项
再算前两项与第三项的结果- 再计算右上角那个式子的结果
由于有
故而有- 接下来,计算左下角那个式子的结果
假设与相同,则
且假设与相同,则- 最后,计算右下角那个式子的结果
由于有
且有
则可得- 最终,将上面这些结果全部合并起来,则可以得到矩阵
且有定义 3.4 即每个 N-半可分矩阵都有一个 N-SSS 表示
2.1.2.2 1-半可分矩阵(标量SSM递归):许多序列模型算法可以归结为结构化矩阵乘法算法
首先,注意1-Semiseparable Matrices会简称1-SS矩阵
接下来,列出1-SS矩阵的特殊情况,此时和是标量,可以从SSS的表示(如上面的公式4所示)或之前这个图所示的
公式4当中出现的
即中提取出来或——去掉和
原因在于对角矩阵易于处理(例如,对角矩阵的乘法与元素级标量乘法相同),故可以忽略这些项
因此,对1-SS矩阵的基本表示是「想曾经在上面的公式3或公式4中,还是」,或如下(定义为公式6)
其等同于标量递归的最小形式——即状态维度 且没有投影的退化SSM情况
值得注意的是,矩阵乘法可以通过如下的式子进行递归计算(定义为公式7)
即
对于上面的
相当于把之前公式3中「其根据计算而来」
中的和都去掉
因此,也将1-SS矩阵的矩阵乘法称为标量SSM递归或累积乘积和(累积乘积和的广义形式)作为递归的基本形式,同时也是本次mamba2主要算法的构建模块
也从侧面说明,许多序列模型的算法可以归结为结构化矩阵乘法算法。1-SS矩阵体现了这一联系:有许多快速算法可以计算原始标量递归或cumprod sum算子,所有这些算法实际上都等价于1-SS矩阵的不同结构分解
2.1.3 SSM是半可分矩阵:使得SSM问题转化为结构化矩阵乘法
回顾一下,我们对SSM的定义是通过定义2.1定义的参数化映射
SSM与半可分矩阵之间的联系仅仅是通过将这种变换写成矩阵乘法,将向量
- 公式(3)
直接建立了SSM与顺序半可分表示之间的联系
而“顺序半可分表示”又等价于一般的半可分矩阵(定义3.3和定义3.4) - 定义 3.5 SSM变换具有状态大小 N,等同于按顺序半可分表示的 N-SS 矩阵的矩阵乘法
换句话说,序列变换算子SSM(定义 2.2)
与矩阵构造算子 SSS(定义3.2) 一致
可以互换使用它们(有时也用SS作为简写)
此外,巧合的是 结构化SSM和顺序半可分矩阵具有相同的缩写,强调了它们的等价性
且可以使用这些缩写 SSM(状态空间模型或半可分矩阵)、SSS(结构化状态空间或顺序半可分) 或 SS(状态空间或半可分)互换使用,以明确地指代任一概念
当然,最终的约定一般是:SSM指状态空间模型,SS指半可分,SSS指顺序半可分
如下图所示,说明了将SSM视为「半可分矩阵——Semiseparable Matrix Transformations」的序列变换视角
- 作为序列变换,SSM可以表示为作用于序列维度T上的矩阵变换𝑀∈R(T,T),在一个头的每个通道中共享相同的矩阵(如上图左侧所示)
- 这个矩阵是一个半可分矩阵(如上图右侧所示),它是一个秩结构矩阵,其中包含在对角线及其以下的每个子矩阵(蓝色)的秩最多为N,等于SSM的状态维度
这个意味着所有计算SSM的算法都可以看作是对半可分矩阵进行结构化矩阵乘法的算法,总之,上面的定义3.5 使得可以将高效计算SSM(及其他序列模型)的问题转化为高效的结构化矩阵乘法算法
补充一句,上图中的右侧——半可分矩阵,其实就是类似之前公式4当中的图
2.1.4(选读) 通过结构化「矩阵乘法计算SSM(含公式8)」
既然上文已经证明了SSM的计算可以转化为结构化矩阵乘法,那接下来,咱们便通过结构化矩阵算法计算SSM
如前所述,半可分矩阵(即秩结构矩阵)是一种经典的结构化矩阵类型:
- 它们具有压缩表示形式,例如SSS形式只有参数,而不是参数
- 它们有直接在压缩表示上操作的快速算法
此外,参数化和矩阵乘法成本在半可分阶中可以非常紧凑
定义3.6 (Pernet, Signargout, 和 Villard (2023))表示:一个 N-SS 矩阵大小为 T可以用 𝑂 (NT)参数表示,并且矩阵-向量乘法在时间和空间上的复杂度为𝑂 (NT)
例如,1-SS 矩阵说明了这种连接的本质。 矩阵 𝑀 = 1SS(𝑎)由正好 T − 1 个参数
回顾一下上面提到过的公式6
并且可以通过遵循上文提过的标量递归公式7在 𝑂 (T)时间内计算
公式7
根据上面的定义3.6可知,只需利用公式(2)
展开递归即可,具体过程如下公式8所示(三个公式分别被定义为8a、8b、8c)
这里, 𝐿 ∈ R(T,T)被定义为 1SS(𝐴),换句话说对于𝑖∈ [N]
该算法涉及三个步骤,对应于上文的公式2:
- 通过输入矩阵 𝐵 (8a)扩展输入 𝑋
- 展开独立的标量SSM递归 (8b),且在步骤(8b)中使用了标量SSM和1-SS矩阵之间的等价关系
- 通过输出矩阵 𝐶 (8c)收缩隐藏状态 𝐻
其实,整个公式8算是mamba1(S6)模型的一个特例,其中扩展的张量Z和H的大小为
2.1.5(选读) N-半可分矩阵的意义:通过低秩且分块的方法加速注意力的计算过程
为方便大家更好的理解,我举一个N-半可分矩阵来实现注意力计算的带有token、序列的完整例子,以让大家最大程度的一目了然,更坚定大家继续读下去的信心与决心(据我了解,这种完整示例,截止到24年9月9之前,还不曾出现过)
众所周知,在标准的注意力机制中,计算 是一个重要的步骤,其中:
- Q是 Query 矩阵,维度为 n×d,其中 n 是序列长度(token数),d 是嵌入维度
- K 是 Key 矩阵,维度为 n×d
直接计算 的复杂度为,因为每个token的Query需要与所有token的Key计算点积。对于长序列,这种计算开销非常大,尤其在实际应用中,如语言模型处理大规模文本数据时
为了降低这种计算复杂度,可以引入一种 结构化的N-半可分矩阵,使得通过与这个结构化矩阵相乘的方式来 近似 或 替代 直接计算的操作
假设我们有一个结构化的N-半可分矩阵 SSS,可以用于加速或简化的计算。这个矩阵 SSS 的作用有两种可能:
- 方式1:低秩近似
在很多实际问题中,矩阵 Q 和 K 可能存在冗余信息,或者数据的某些维度比其他维度更重要。引入N-半可分矩阵可以将 的高维度操作分解为多个低维度的操作
例如:如果 SSS 是一个低秩矩阵,它可以近似表示 ,然后我们计算 代替 ,从而降低计算复杂度——减少点积计算中的维度,进而加速计算 - 方式2:分块计算
另一种方式是,N-半可分矩阵 SSS 可以将 Q 和 K 的计算过程分成多个独立的小块。换句话说,N-半可分矩阵将Q 和 K 分成若干个较小的子矩阵,通过分块计算代替整体的矩阵乘法操作
例如,假设 SSS 将 K 分解为两个较小的子矩阵 和 ,那么我们可以分步计算和 ,再将结果组合起来,这种方式可以显著减少计算的复杂度
举个例子,假设我们有一个序列:"Who is the July Online founder",这将被模型转化为一组token:[who, is , the, July, Online, founder],每个token会被嵌入到一个向量空间中,从而生成一个二维矩阵
- 我们将每个token的嵌入向量看作是矩阵的行(假设每个嵌入向量的维度是3),因此我们会得到一个形状为 6×3 的矩阵,其中7是token的数量,3是嵌入维度
假设每个token的嵌入向量为Who is the July Online founder Q = [ [0.1, 0.2, 0.3], [0.5, 0.6, 0.7], [0.9, 1.0, 1.1], [1.3, 1.4, 1.5], [1.7, 1.8, 1.9], [2.1, 2.2, 2.3] ]
- 在标准的点积注意力计算中,首先有三个矩阵:Query、Key和 Value——即输入嵌入与「权重矩阵//」计算后的矩阵
假设近似 Query, Key, 和 Value 矩阵的三个半可分矩阵分别为:
Q (6x3):表示每个token的查询向量
K (6x3):表示每个token的键向量
V (6x3):表示每个token的值向量
为何可以找到三个半可分矩阵来近似呢,原因在于在许多实际问题中,数据的特征矩阵(如Q, K, V)可能具有低秩性质,也就是矩阵中的信息主要集中在某些低维子空间中。换句话说,数据中许多特征可能存在一定的相关性,可以通过较少的特征来近似表示。因此,可以假设这些矩阵是N-半可分的,且还可以将其分解为若干子矩阵 - 既然这三个矩阵 Q, K, 和 V 可以N-半可分的,意思是它们可以被分解成若干个低维子矩阵
比如Q, K, 和 V 矩阵中的每个token的向量可以分解为两个子向量
例如:对于每个token的嵌入向量,假设它们可以分解为两个部分:前两个维度是一个子向量,最后一个维度是另一个子向量
那么矩阵 Q 可以表示为:
Q1 = 第一列和第二列组成的子矩阵
Q2 = 第三列组成的子矩阵Q_1 = [ [0.1, 0.2], [0.5, 0.6], [0.9, 1.0], [1.3, 1.4], [1.7, 1.8], [2.1, 2.2] ]
类似地,K 和 V 也可以这样分解:Q_2 = [ [0.3], [0.7], [1.1], [1.5], [1.9], [2.3] ]
K1 和V1 表示前两个维度的子矩阵
K2 和V2 表示最后一个维度的子矩阵 - 接下来,直接计算对应的注意力
- 计算 和 ,得到两个部分的注意力权重
- 对这两个结果分别进行softmax,再将它们合并起来,得到最终的注意力权重
- 加权V矩阵: 最后,我们用分离出来的注意力权重对 V1和 V2 分别进行加权和计算,最终将结果拼接起来,形成最终的输出
2.2 SMA(结构化掩码注意力):使用结构化矩阵推广线性注意力
2.2.1 从自注意力、核注意力到掩码(核)注意力:含公式9-13
注意力的基本形式(单头)是对三个向量序列的映射 (𝑄, 𝐾, 𝑉) ↦ →𝑌,如下所示(定义为公式9)
可以使用““shape annotation”来表示张量的维度,例如 𝑄 ∈ R(T,N),其中
- S和 T表示源和目标序列长度,分别意指:source、target之意
- N表示特征维度
- P表示头维度
最常见的softmax注意力变体使用softmax激活 𝑓 = softmax来规范 𝐺矩阵的行
此外
- 虽然注意力通常被框定为对这三个对称视图输入𝑄, 𝐾, 𝑉的操作,但公式9 中的输入和输出维度表明情况并非如此(特别是,输出中不存在特征维度 N时)
- 因此在 S = T(例如自注意力)的情况下,将 𝑉视为主要输入,因此公式9 定义了一个适当的序列变换 𝑉 → 𝑌
2.2.1.1 自注意力
对于自注意力,其中
- (i) 源序列和目标序列相同(即 S = T)
- (ii) 通常特征维度和头维度相同(即 N = P)
- (iii) 并且𝑄, 𝐾, 𝑉是通过对同一输入向量的线性投影生成的,即
2.2.1.2 核注意力
// 待更
2.2.1.3 掩码(核)注意力:对公式10的分解
设 𝐿为形状为 (T, S)的掩码。 最常见的是,在自回归自注意力情况下,当 S = T时, 𝐿可能是一个下三角矩阵,表示因果掩码
除了强制因果关系外,还可以应用许多其他类型的掩码——特别是各种稀疏模式,如带状、扩展
或块对角线——这些都是为了减少密集注意力的复杂性
掩码注意力通常用矩阵表示法表示为「定义为公式10,如果你读的细致的话,你会发现这个公式10其实早在本文的《1.2 一系列定义:注意力机制、结构化矩阵、SSD》,便已出现过」
更准确地说,带有shape annotation并将其分解为精确的计算序列(定义为公式11):
在本节中改进的注意力变体推导从注意到这个公式可以写成一个单一收缩开始(定义为公式12):
而算法11可以通过特定的成对收缩顺序重新表述为算法12的形式,如下公式13所示
2.2.2 线性注意力:含公式14、公式15(SMA的线性对偶形式)
如下公式14所示的线性注意力
等价于10:
接下来,以另一种顺序执行上面的公式12,从而得到下面的公式15——算是SMA的线性对偶形式
其中
- 第一步(15a)通过特征维度 N的因子执行“扩展”到更多特征
- 第二步(15b)是最关键的,并解释了线性注意力的线性部分
首先,注意到(15b)只是通过 𝐿进行直接矩阵乘法「因为 (P, N)轴可以被展平」
且还要注意,这是唯一涉及 T和 S轴的项,因此应该具有 Ω(TS)复杂度(即序列长度的二次方)
然而,当掩码 𝐿是标准的因果注意力掩码(下三角全为1)时,通过 𝐿进行矩阵-向量乘法与特征逐项累积和相同
为方便理解,可再回顾下公式7
而怎么来呢,根据而来呀,而不就是公式6么 - 第三步(15c)收缩扩展的特征维度。 如果将 𝐾视为输入(如上文2.2.1节开头所述),那么 𝑉和 𝑄分别执行扩展和收缩
2.2.3 SMA(结构化掩码注意力):可实例化为任何给定的矩阵结构类别
通过掩码注意力的张量收缩视角(如公式15所示),得知原始线性注意力的关键在于带有因果掩码的矩阵-向量乘法等同于累加求和运算(we can immediately see that the crux of the original linear attention is the fact that matrix-vector multiplication by the causal mask is equivalent to the cumulative sum operator)
- 然而,观察到没有理由注意力掩码必须全是1。 线性注意力快速的必要条件是 𝐿是一个结构化矩阵,根据定义,这些矩阵具有快速矩阵乘法(根据上文1.2.2节所述的结构化矩阵 所述)
- 特别是,我们可以使用任何矩阵-向量乘法复杂度低于二次方(理想情况下是线性)的掩码矩阵「we can use any mask matrix 𝐿 that has sub-quadratic (ideally linear) matrix-vector multiplicat」,这将通过加速瓶颈公式(15b)使其具有与标准线性注意力相同的复杂度
定义 4.2 结构化掩码注意力SMA(或简称结构化注意力)被定义为一个函数作用于查询/键/值𝑄, 𝐾, 𝑉以及任何结构化矩阵 𝐿 (即具有低于二次复杂度的矩阵乘法—— sub-quadratic matrix multiplication),通过四维张量收缩
- SMA二次模式算法是通过(公式13)定义的成对收缩序列,对应于标准的(掩码)注意力计算
- SMA线性模式算法是通过(公式15)定义的成对收缩序列,其中步骤(15b)通过二次结构矩阵乘法进行优化
总之,可以将SMA实例化为任何给定的矩阵结构类别,比如如下图所示的一些实例「SMA constructs a masked attention matrix(掩码注意力矩阵) for any structured matrix 𝐿, which defines a matrix sequence transformation 𝑌 = 𝑀𝑉」
- 线性注意力使用因果掩码
- RetNet使用衰减掩码,其中,对于某些衰减因子
RetNet (Y. Sun et al. 2023) uses a decay mask 𝐿𝑖 𝑗 = 𝛾𝑖 − 𝑗 · I[ 𝑗 ≥ 𝑖] for some decay factor 𝛾 ∈ [0, 1] - SSD使用1-半可分(1-semiseparable)
- 衰减掩码可以推广到Toeplitz矩阵对于某些可学习的(或依赖于输入的)参数集
这可以解释为一种相对位置编码形式,类似于其他方法如AliBi,但乘法而不是加法 - 另一种变体可以使用傅里叶矩阵(Fourier matrix)以不同的方式编码位置结构
2.3 总结:再谈SSD(状态空间对偶性):含公式16之SSM的二次对偶形式
2.3.1 标量-恒等的结构化状态空间模型及其示例
回想一下,SSM由定义,SSM的矩阵形式使用SSS(顺序半可分)表示,其中公式3
现在让我们考虑只是一个标量的情况;换句话说,这是一种结构化 SSM 的实例,其中 矩阵具有极其特殊的结构:,其中 是一个标量, 是单位矩阵
然后可以重新排列
这可以向量化为
其中, A 的特性(这里是标量 a)被用来构建 L,L 是一个由 a 定义的序列,用于表示状态转移, 而𝐵, 𝐶 ∈ R(T,N)
使用这种公式,完整的输出 𝑌 = 𝑀X精确计算为公式16——算是SSM的二次对偶形式(也可以认为是SMA的二次对偶形式)
其中 S = T,从而可以看到这与掩码核注意力公式13的原始定义完全相同
因此,如「第2.1.4 通过结构化矩阵算法计算SSM」所述,计算标量结构化SSM——通过实现半可分矩阵𝑀并执行二次矩阵-向量乘法——与二次掩码核注意力完全相同
为了更好的理解上述过程,我再举一个详细的示例
- 首先,定义系统
假设我们有一个二维系统,其状态由两个变量 和 描述,我们的目标是描述这个系统如何从当前状态 转移到下一个状态
系统的状态转移和输出方程定义如下:
其中:
是在时间 t的系统状态
是状态转移矩阵
是控制输入矩阵
是输出矩阵
是在时间 t 的外部输入- 其次,定义一系列矩阵和参数
为了简化,我们假设:
,其中 是一个标量,是单位矩阵
是控制输入矩阵
是输出矩阵
是一个常数输入
且定义初始条件
是初始状态- 接着,做状态转移计算
使用 , 和 的定义来计算 和
时间 到 :
时间 到 :- 然后,使用矩阵
在 的特殊情况下,我们可以构建 矩阵,其中 。对于 到 , 矩阵为():- 最后,计算M矩阵
这个 矩阵描述了从输入 到输出 的映射,通过 可以计算出输出
2.3.2 1-半可分结构化掩码注意力
SMA允许使用任何结构化掩码
当是因果掩码时,它是标准的线性注意力。 注意,因果掩码是,即1-SS掩码由公式6中的生成
进一步,对于,而言,其非常类似于注意力计算
- 事实上,如果所有的
- 那么只是下三角因果掩码且等同于因果线性注意力(then is simply the lower-triangular causal mask and is equivalent to causal linear attention)
- 而其中的这不就相当于
相当于C B X类比于Q K V
毕竟,可曾还记得上面的公式3
这激发了将推广到1-半可分掩码类,或1-半可分结构化
掩码注意力(1-SS SMA),其中线性注意力递归中的cumsum被更一般的递归——标量SSM扫描,即1-半可分矩阵乘法所取代
最后,我们考虑1-半可分SMA的最重要原因是计算它的线性形式是对角SSM的一个特例。SMA的线性形式是算法(15),其中瓶颈步骤(15b)可以看作是通过1-SS掩码进行矩阵乘法
第三部分 从硬件高效的SSD算法、到Mamba-2 架构
3.1 硬件高效的SSD算法:块分解、对角块、低秩块
定义6.1 考虑一个具有状态扩展因子 N和头部维度 P = N的SSD模型,存在一种算法可以在任何输入上计算模型,该算法只需要训练FLOPs,推理FLOPs,推理内存,其工作主要由矩阵乘法主导
注意,所有这些界限都是紧的,因为具有状态扩展 N的SSM在头部大小为时,总状态大小为 「分别得出训练和推理 FLOPs 的下界为和 」。此外,输入本身有个元素,从而产生了内存下限
如下图所示,状态空间对偶描述了SSM和掩码注意力之间的密切关系
- 上图左侧:一般的 SSM和 SMA 都具有线性和二次形式,在符号上有直接的类比
比如,SSM的线性形式为公式8b,SSM的二次对偶形式为公式16 对于上面的公式16如果是作为SSM的二次对偶形式:在SSM的框架下,公式16描述了如何通过状态转移矩阵 A、输入到状态的映射矩阵 B和状态到输出的映射矩阵 C 来计算序列的输出
如果是作为SMA的二次对偶形式:在SMA的框架下,公式16展示了如何通过引入结构化矩阵(例如,通过SSM定义的矩阵)来优化传统的注意力计算。这种方法允许使用掩码矩阵 L 来控制信息流,并通过结构化矩阵与查询 Q、键 K 和值 V 矩阵的结合来加速计算 (注意,标红这句话值得反复体会三遍)
再比如,SMA的线性形式为公式15、SMA的二次形式为公式13a(当然,这个13a由公式10衍变而来) - 上图右侧:SSM 和 SMA 在一大类状态空间对偶模型(SSD) 上相交,这些模型捕捉了许多序列模型作为特例
定义6.1背后的主要思想是再次将计算SSM的问题视为半可分矩阵乘法,但以一种新的方式利用其结构,即不是在递归或注意模式下计算整个矩阵,而是对矩阵进行块分解
- 对角块可以使用对偶注意模式计算,这可以通过矩阵乘法高效完成
- 而非对角块可以通过半可分矩阵的秩结构进行分解并简化为较小的递归
背景铺垫:块分解
首先,将矩阵 𝑀划分为一个的子矩阵网格,每个子矩阵的大小为 Q × Q,对于某个块大小 Q。 注意,根据半可分矩阵的定义性质(定义3.1),非对角块是低秩的
如下图所示,分别体现的是块分解、对角块、低秩块
举个例子,例如对于 T = 9 并分解成长度为 Q = 3 的块
上图中的阴影部分是半可分矩阵的非对角块的低秩分
从这里我们可以将问题简化为这两个部分。 这些也可以解释为将“块” 的输出分为两个部分:
- 块内输入的影响
- 以及块之前输入的影响
然后,如果要完成状态空间对偶(SSD)模型的完整 PyTorch代码,则可以先定义符号来定义批量矩阵乘法与批次维度 B
从而可以推断出效率的三个方面:
- 计算成本:总共FLOPs
- 内存成本:总共空间
- 并行化:更大的 M, N, K项可以利用现代加速器上的专用矩阵乘法单元
def segsum(x):
"""朴素的段和计算。exp(segsum(A)) 生成一个 1-SS 矩阵,等价于一个标量 SSM """
T = x.size(-1)
x_cumsum = torch.cumsum(x, dim=-1)
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
def ssd(X, A, B, C, block_len=64, initial_states=None):
"""
Arguments:
X: (batch, length, n_heads, d_head)
A: (batch, length, n_heads)
B: (batch, length, n_heads, d_state)
C: (batch, length, n_heads, d_state)
Return:
Y: (batch, length, n_heads, d_head)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0
# 重新排列成块/段
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1
为方便形象理解,再贴个图,如下(来源于此文中的3.1.1.3节“mamba:从S4到S6的算法变化流程”)
3.1.1 对角块
对角块很容易处理,因为它们只是较小规模的自相似问题。 𝑗-th 块表示计算范围内的答案
- 特别地,对于小块长度 Q,这个问题可以通过对偶二次SMA形式更有效地计算
其中,二次SMA计算的成本包括三个步骤「Center Blocks. The cost of the quadratic SMA computation consists of three steps (equation (16)),至于为何是公式16,上文早已着重解释过了,在于公式16可以认为是SSM的二次对偶形式,也可以理解为SMA的二次对偶形式」: i) 计算核矩阵,其成本为 BMM( T/Q, Q, Q, N)
ii) 乘以掩码矩阵,这是对形状为 ( T/Q, Q, Q)的张量进行的逐元素操作
iii) 乘以 𝑋值,其成本为 BMM( T/Q, Q, P, N)
此外,这些块可以并行计算 - 这些子问题可以解释为:假设初始状态(到块)为 0,每块的输出是什么。换句话说,对于块 𝑗,这将计算正确的输出,仅考虑块输入
对应的代码为
# 1. 计算每个块内(对角块)的输出
L = torch.exp(segsum(A))
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
3.1.2 低秩块:右B-块因子、中心A-块因子、左C-块因子三个部分的结算
低秩分解由3个项组成,相应地有三部分计算
- 像下面这样的项被称为右因子或 𝐵-块因子
此步骤计算低秩分解的右 𝐵-块因子的乘法。 注意,对于每个块,这是一个(N, Q)乘(Q, P)的矩阵乘法,
其中 N是状态维度, 𝑃是头维度。 每个块的结果是一个(N, P)张量,其维度与扩展的隐藏状态ℎ相同
这可以解释为:假设初始状态(到块)为 0,每个块的最终状态是什么。 换句话说,这计算了 ,其中
对应的代码为
这一步是一个单一的矩阵乘法,成本为 BMM( T/Q, N, P, Q)# 2. 计算每个块内的状态 # (低秩分解的非对角块的右项;B项) decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
- 像这样的项被称为中心因子或 𝐴-块因子
这一步计算了低秩分解中中心 𝐴-块因子项的影响。 在前一步中,每个块的最终状态的总形状为
现在通过一个由这现生成的1-SS矩阵相乘:
这一步可以通过任何用于计算1-SS乘法的算法来计算(也称为标量SSM扫描或累积乘积和操作符)
这可以解释为:每个块的实际最终状态是什么考虑到所有先前的输入; 换句话说,这计算了真实的隐藏状态(考虑到所有的)
对应的代码为
这一步是长度为 T/Q的标量SSM扫描(或1-SS乘法),在 (N, P)独立通道上进行。 这次扫描的工作是 TNP/Q,这是相对于其他因素可以忽略不计的# 3. 计算块间SSM递归;在块边界生成正确的SSM状态 # (非对角块分解的中间项;A项) if initial_states is None : initial_states = torch.zeros_like(states[:, :1]) states = torch.cat([initial_states, states], dim=1) decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) states, final_state = new_states[:, :-1], new_states[:, -1]
请注意,由于阻塞将序列长度从 T减少到 T/Q,这次扫描的成本比纯SSM扫描(例如Mamba的选择性扫描)小 Q倍
因此,我们观察到在大多数问题长度上,其他算法(附录B)可能更有效或更容易实现,而不会显著减慢速度
例如,通过1-SS矩阵乘法的简单实现成本为 BMM(1, T/Q, NP, T/Q),这比简单的递归/扫描实现更容易实现且可能更有效 - 像下面这样的项被称为左因子或 𝐶-块因子
这一步计算了左 𝐶-块因子的低秩分解的乘法。 对于每个块,这可以通过矩阵乘法来表示
这可以解释为:每个块的输出是什么考虑到正确的初始状态,并假设输入为 0
换句话说,对于块 𝑗,这计算了仅考虑先前输入的正确输出
对应的代码为
这一步是一个单一的矩阵乘法,成本为 BMM(T/Q, Q, P, N)# 4. 计算每个块的状态到输出的转换 # 低秩分解的非对角块的左项;C项 state_decay_out = torch.exp(A_cumsum) Y_off = torch.einsum( 'bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out
最后,如下# 添加块内和块间项的输出(对角块和非对角块) Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p") return Y, final_state
整个的过程,可以用下图表示
其中
- 橙色的diagonal block代表input到output,涉及到上面所说的对角块——对角块表示块内计算
而下面的非对角块表示则通过SSM的隐藏状态进行的块间计算
- 绿色的low-rank block代表input到state,类似中的
相当于上面介绍过的右因子或 𝐵-块因子 - 黄色的low-rank block代表state到state,类似中的
相当于上面介绍过的中心因子或 𝐴-块因子 - 蓝色的low-rank block代表state到output,类似
相当于上面介绍过的左因子或 𝐶-块因子
总之,通过使用SSM的矩阵变换视角将其写成半可分矩阵,通过块分解矩阵乘法算法开发了更硬件高效的SSD模型计算,矩阵乘法也可以解释为SSM,其中块表示输入和输出序列的分块
注意,上图可以配合下图一块看(来源于此文中2.1.1节的“离散数据的连续化:基于零阶保持技术做连续化并采样”的最后)
// 待更
3.2 Mamba-2 架构
如下图所示,Mamba-2模块通过去除序列线性投影简化了Mamba模块(The Mamba-2 block simplifies the Mamba block by removing sequential linear projections)
- SSM参数𝐴, 𝐵, 𝐶在模块开始时生成,而不是作为SSM输入𝑋 的函数
the SSM parameters 𝐴, 𝐵, 𝐶 are produced at the beginning of the block instead of as a function of the SSM input 𝑋 . - 添加了一个额外的归一化层,如NormFormer,提高了稳定性
An additional normalization layer is added as in NormFormer (Shleifer, Weston, and Ott 2021), improving stability. - 𝐵和𝐶投影只有一个头部,在𝑋头部之间共享,类似于多值注意力(MVA)
The 𝐵 and 𝐶 projections only have a single head shared across the 𝑋 heads, analogous to multi-value attention (MVA)
3.2.1 模块设计:并行参数投影、额外的归一化
我们首先讨论对神经网络模块的修改,这些修改独立于内部序列混合层(即核心SSD层之外)
3.2.1.1 并行参数投影
对比mamba1和mamba2可知
- Mamba-1的动机是基于SSM中心的观点,其中选择性SSM层被视为从 𝑋 → 𝑌的映射(Mamba-1 was motivated by an SSM-centric point of view where the selective SSM layer is viewed as a map from 𝑋 → 𝑌 )
SSM参数, 𝐵, 𝐶被视为辅助参数,是SSM输入 𝑋的函数。 因此,定义(𝐴, 𝐵, 𝐶)的线性投影——在初始线性投影创建𝑋之后进行(The SSM parameters 𝐴, 𝐵, 𝐶 are viewed as subsidiary and are functions of the SSM input 𝑋 . Thus the linear projections defining (𝐴, 𝐵, 𝐶) occur after the initial linear projection to create 𝑋) - 在Mamba-2中,SSD层被视为从𝐴, 𝑋, 𝐵, 𝐶 → 𝑌的映射。 因此,有必要在块的开头通过单个投影并行生成𝐴, 𝑋,𝐵, 𝐶(In Mamba-2, the SSD layer is viewed as a map from 𝐴, 𝑋, 𝐵, 𝐶 ↦ → 𝑌 . It therefore makes sense to produce 𝐴, 𝑋, 𝐵, 𝐶 in parallel with a single projection at the beginning of the block) 值得注意的是
这与标准注意力架构类比,其中𝑋, 𝐵, 𝐶对应于并行创建的𝑄, 𝐾, 𝑉投影(Note the analogy to standard attention architectures, where 𝑋, 𝐵, 𝐶 correspond to the 𝑄, 𝐾, 𝑉 projections that are created in parallel.)
为SSM的𝐴, 𝐵, 𝐶, 𝑋输入采用并行投影略微减少了参数,更重要的是,通过使用标准的Megatron分片模式,更适合于较大模型的张量并行(Note that adopting parallel projections for the 𝐴, 𝐵, 𝐶, 𝑋 inputs to the SSM slightly reduces parameters and more importantly is more amenable to tensor parallelism for larger models, by using standard Megatron sharding patterns)
3.2.1.2 额外的归一化
在初步实验中,发现较大模型中容易出现不稳定性
通过在最终输出投影之前的块中添加一个额外的归一化层(例如LayerNorm、GroupNorm或RMSNorm)来缓解这一问题。 这种归一化的使用与NormFormer架构最直接相关,该架构也在MLP和MHA块的末端添加了归一化层
且mamba2的作者还注意到,这一变化类似于其他最近与Mamba-2相关的模型,这些模型是从线性注意力视角推导出来的
- 原始的线性注意力公式通过一个分母项进行归一化,该分母项模拟了标准注意力中softmax函数的归一化
而TransNormerLLM和RetNet发现这种归一化是不稳定的,并在线性注意力层之后添加了额外的LayerNorm或GroupNorm - mamba2的额外归一化层与这些略有不同,发生在乘法门(multiplicative gate)分支之后而不是之前
Our extra normalization layer differs slightly from these, occuring after the multiplicative gate branch instead of before
3.2.2 序列变换的多头模式:多查询、多键、多值
回想一下,SSM被定义为一个序列变换
其中:
- 𝐴, 𝐵, 𝐶 参数具有状态维度 N
- 它们定义了一个序列变换,例如可以表示为矩阵
- 该变换作用于输入序列,独立于 P轴
可以将其视为定义了序列变换的一个 head
定义 7.1(多头模式) 多头序列变换由 H个独立的头组成,总模型维度为 D = d_model。参数可以在各头之间共享,形成一个head模式
- 状态大小 N和头维度 P类似于注意力机制中的 头维度和 头维度(The state size N and head dimension P are analogous to the 𝑄𝐾 head dimension and 𝑉 head dimension of attention, respectively)
- 正如在现代Transformer架构中(比如Google的PaLM、Meta的Llama),在Mamba-2中我们通常选择这些常数为64或128;当模型维度 D增加时,我们增加头的数量,同时保持头维度 N和 P不变(when the model dimension D increases, we increase the number of heads while keeping the head dimensions N and P fixed)
为了描述如何做到这一点,我们可以从多头注意力中转移和推广想法,以定义SSM或任何一般序列变换的类似模式(in order to describe how to do this, we can transfer and generalize ideas from multihead attention to define similar patterns for SSMs, or any general sequence transformation)
- 多头状态空间模型 (MHS) / 多头注意力机制 (MHA) 模式
Multihead SSM (MHS) / Multihead Attention (MHA) Pattern
经典的 MHA 模式假设头维度 P可以整除模型维度 D
头的数量定义为 H = D/P(比如transformer论文中,模型维度512,8个头,每个头的维度为512/8 = 64),然后,通过创建 H个核心序列变换的副本,通过创建每个参数的 H个独立副本来实现
请注意,虽然MHA模式最初是为注意力序列变换描述的,但它可以应用于与定义2.1兼容的任何事物。例如,多头SSD层将接受形状符合方程(17)的输入,其中SSD算法在 H = n_heads维度上广播 - Multi-contract SSM (MCS)/多查询注意力(MQA)模式
Multi-contract SSM (MCS) / Multi-query Attention (MQA) Pattern
多查询注意力(详见此文:一文通透各种注意力:从多头注意力MHA到分组查询注意力GQA、多查询注意力MQA),顾名思义,即多个query 单个key value,如下图最右侧所示:Multi-query 可以显著提高自回归推理的速度,这依赖于缓存𝐾和𝑉张量。 这种技术只是避免给𝐾和𝑉额外的头维度,换句话说,就是将(𝐾, 𝑉)的单个头广播到𝑄的所有头上
利用状态空间对偶性,我们可以将MQA的等效SSM版本定义为方程(18) 其中, 𝑋 和 𝐵(注意力的 𝑉 和 𝐾 的SSM类比)在 H个头之间共享,也称之为多收缩SSM (MCS)头模式,因为控制SSM状态收缩的 𝐶 参数在每个头中都有独立的副本
相当于X B C类比于V K Q
此外,多查询注意力的思想可以扩展到分组查询注意力(分组头模式Grouped Head Patterns):而不是1个K和V头,可以创建 G个独立的K和V头,其中1 < G且 G整除 H(如上图中部所示)
这既是为了弥合多查询和多头注意力之间的性能差异,也是为了通过将 G设置为分片数量的倍数来实现更高效的张量并行 - 多键注意力 (MKA) 或多扩展SSM (MES)头模式
其中控制SSM扩展的 𝐵在每个头中是独立的,而 𝐶和 𝑋在头之间共享 - 多输入SSM (MIS) / 多值注意力(MVA) 模式
Multi-input SSM (MIS) / Multi-value Attention (MVA) Pattern
虽然MQA对于注意力来说是有意义的,因为它有KV缓存,但它不是SSM的自然选择
在Mamba中, 𝑋被视为SSM的主要输入,因此 𝐵和 𝐶是跨输入通道共享的参数,而在公式(20)中定义了一种新的多值注意力 (MVA) 的多输入 SSM (MIS) 模式,这同样可以应用于任何序列变换,例如 SSD
上面的描述可能比较绕,我给大家画个图,便一目了然了
首先,对于下图三种模式中的C B X都是可以逐一和注意力中的Q K V对应的,且当某个模式中的或C、或B、或X被圈起来时,则代表它的数量是更多的 属于多个,而没被圈起来的则可能是单个
具体而言,可以简单粗暴的理解为:
- 多查询便是多个Query 单个Key 单个Value
相当于对应:多个C 单个B 单个X- 多键便是多个Key 单个Query 单个Value
相当于对应:多个B 单个C 单个X- 多值便是多个Value 单个Query 单个Key
相当于对应:多个X 单个C 单个B
定义7.2 mamba1的重新定义
Mamba 架构的选择性SSM(S6)层可以被视为具有
- 头维度 𝑃 = 1: 每个通道都有独立的 SSM 动态 𝐴
- 多输入SSM(MIS) 或多值注意力(MVA)头结构(如上图最右侧所示):输入𝑋的所有通道共享𝐵、𝐶矩阵(对应于注意力对偶中的𝐾、Q)
因为通过实验证明,Mamba中最初使用的MVA模式表现最佳
此外,值得一提的是,Mamba-2中使用的多输入SSM头模式(multi-input SSM head pattern,比如8个X 1个C 一个B),可以轻松扩展到分组输入SSM(grouped-input SSM,GIS,比如8个X 4个C 4个B),或同义的分组值注意力(grouped-value attention,GVA,还是value对应的X最多,然后 C B相对少)
3.2.3 线性注意力的其他SSD扩展
// 待更
3.3 SSM的系统优化:张量并行、序列并行、可变长度
3.3.1 张量并行Tensor Parallel
张量并行「Tensor parallelism,简称TP,详见此文《大模型并行训练指南:通俗理解Megatron-DeepSpeed之模型并行与数据并行》的第二部分 张量并行(Tensor Parallelism,算模型并行的一种)」是一种模型并行技术,它将每一层(例如,注意力机制,MLP)拆分在多个加速器(如 GPU)上运行
- 这种技术被广泛用于在 GPU 集群上训练大多数大型模型(Brown 等,2020;Chow dhery 等,2023;Touvron, Lavril 等,2023;Touvron, L. Martin 等,2023)
其中每个节点通常有 4-8 个 GPU,并具有快速网络连接,如 NVLink - TP 最初是为 Transformer 架构开发的,没法直接适应其他架构,故在Mamba 架构中使用 TP 有一定的挑战,进一步,Mamba-2 架构用起来TP之后,还得考虑如何设计以使 TP 高效
回顾 Mamba 架构,单个输入(为简单起见,不进行批处理),输入投影矩阵,其中 是扩展因子(通常为2),输出投影矩阵
使用 TP,假设想将计算分配到 2 个 GPU 上
- 很容易将输入投影矩阵和分成两个大小为的分区
It is easy to split the input projection matrices 𝑊 (𝑥 ) and 𝑊 (𝑧 ) into two partitions each of size 𝑑 × 𝑒𝑑/2 - 然后每个 GPU 将持有大小为的一半
Then each GPU would hold half of 𝑥𝑐 of size 𝐿 × 𝑒𝑑/2 - 然而,由于 Δ, 𝐵, 𝐶是的函数,所以需要在 GPU 之间进行额外的全归约,以在计算Δ, 𝐵, 𝐶之前获得整个
However,we see that since Δ, 𝐵, 𝐶 are functions are 𝑥𝑐 , so we would need an extra all-reduce between the GPUs to get the whole of 𝑥𝑐 before computing Δ, 𝐵, 𝐶 - 之后,由于它们在𝑑上是独立的,因此两个 GPU 可以并行计算 SSM
After that the two GPUs can compute the SSM in parallel since they are independent
along 𝑑 - 最后,我们可以将输出投影矩阵分成两个大小为的分区,并在最后进行一次全规约
At the end, we can split the output projection matrices 𝑊 (𝑜 ) into two partitions each of size 𝑒𝑑/2 × 𝑑, and do an all-reduce at the end
上述整个过程,与Transformer相比,将进行两次全规约,而不是一次,从而使通信时间加倍(Compared to Transformers, we would incur two all-reduces instead of one, doubling the time spent in communication)
对于大规模Transformer训练,通信可能已经占用了相当大的一部分时间(例如10-20%),加倍通信将使Mamba在大规模训练中效率不高「For large-scale Transformers training, communication might already take a significant fraction of time(e.g. 10-20%), and doubling communication would make Mamba not as efficient for large-scale training」
使用Mamba-2的目标是每个块只有一次全规约,类似于Transformer中的注意力或MLP块。因此,通过投影直接从𝑢得到Δ, 𝐵, 𝐶,而不是从得到,从而允许拆分这些投影矩阵
这意味着我们在不同的GPU上有不同的 Δ, 𝐵, 𝐶集合,这相当于在一个更大的“逻辑GPU”上有几个“组”的 Δ, 𝐵, 𝐶。此外,在每个块内使用GroupNorm,组的数量可被TP度整除,这样TP组中的GPU在块内无需通信:
可以看到,只需要拆分输入投影矩阵和输出投影矩阵,并且只需要在块的末尾进行全归约。 这类似于注意力和MLP层的TP设计
特别地,如果有TP度为2,则会拆分
- ,其中
- ,其中
- ,其中
对于 𝑖 = 1, 2,TP Mamba- 2层可以写成
总之,如下图所示
- 左侧是张量并行,分割输入投影矩阵、和输出投影矩阵
每个SSM头 (𝐴, 𝐵, 𝐶, 𝑋) →𝑌存在于单个设备上,选择GroupNorm作为最终归一化层可以避免额外的通信。每层需要一次全归约,就像Transformer中的MLP或注意力块一样 - 右侧是序列/上下文并行,类似于SSD算法,使用多个设备,可以沿序列维度进行分割,每个设备计算其序列的状态,然后将该状态传递给下一个GPU
3.3.2 序列并行
对于非常长的序列,可能需要沿序列长度维度将输入和激活拆分到不同的GPU上。 有两种主要技术:
- 用于残差和归一化操作的序列并行(SP):由Korthikanti等人首次提出,这种技术将TP中的all-reduce分解为reduce-scatter和all-gather
注意到在同一TP组中的所有GPU上,残差和归一化操作在相同输入上重复进行,SP通过执行:reduce-scatter、残差和归一化,然后all-gather,沿序列长度维度拆分激活
由于Mamba-2架构使用相同的残差和归一化结构,SP无需修改即可应用 - 序列并行用于token混合操作(注意力或SSM),也称为“上下文并行”(context parallelism,简称CP)。已经为注意力层开发了几种技术「例如,环形注意力(Liu, Yan, et al. 2024; Liu, Zaharia和 Abbeel 2023),使用复杂的负载均衡技术(Brandon 等人,2023)
注意力机制中的序列并行问题在于可以将查询和键分成块,但每个查询块需要与键块交互,导致通信带宽与工作者数量呈二次方关系
使用 SSMs,可以简单地分割序列:每个工作者获取一个初始状态,计算其输入的 SSM,返回最终状态,并将最终状态传递给下一个工作者。 通信带宽与工作者数量呈线性关系。 这种分解与 SSD 算法中的块分解完全相同,可以分成块/块
且在上图 中说明了这种上下文并行性
3.3.3 可变长度
虽然预训练通常对批次使用相同的序列长度,但在微调或推理过程中,模型可能需要处理不同长度的输入序列。
一种处理这种情况的简单方法是将批处理中所有序列右填充到最大长度,但如果序列长度差异很大,这可能效率低下。 对于Transformer,已经开发了复杂的技术来避免填充,并在GPU之间进行负载平衡(Zeng等,2022;Y.Zhai等,2023),或者在同一批次中打包多个序列并调整注意力掩码(Ding等,2024;Pouransari等,2024)
对于SSM,特别是Mamba,可以通过简单地将整个批次视为一个长序列来处理可变序列长度,并避免在单个序列之间传递状态。 这相当于简单地设置,对于一个序列末尾的token 𝑡,以防止它将信息传递给属于不同序列的token 𝑡 + 1
// 待更