重温被Mamba带火的SSM:线性系统和HiPPO矩阵

9270d069b09b5823e4c9c51c5412252c.png

本文约4800字,建议阅读10分钟
本文介绍了ssm相关内容。

前几天,笔者看了几篇介绍 SSM(State Space Model)的文章,才发现原来自己从未认真了解过 SSM,于是打算认真去学习一下 SSM 的相关内容,顺便开了这个新坑,记录一下学习所得。

SSM 的概念由来已久,但这里我们特指深度学习中的 SSM,一般认为其开篇之作是 2021 年的 S4,不算太老,而 SSM 最新最火的变体大概是去年的 Mamba [1]。

当然,当我们谈到 SSM 时,也可能泛指一切线性 RNN 模型,这样 RWKV [2]、RetNet [3] 、LRU 都可以归入此类。不少 SSM 变体致力于成为 Transformer 的竞争者,尽管笔者并不认为有完全替代的可能性,但 SSM 本身优雅的数学性质也值得学习一番。

尽管我们说 SSM 起源于 S4,但在 S4 之前,SSM 有一篇非常强大的奠基之作《HiPPO: Recurrent Memory with Optimal Polynomial Projections》[4](简称 HiPPO),所以本文从 HiPPO 开始说起。

01 基本形式

先插句题外话,上面提到的 SSM 代表作 HiPPO、S4、Mamba 的一作都是 Albert Gu [5],他还有很多篇 SSM 相关的作品,毫不夸张地说,这些工作筑起了 SSM 大厦的基础。不论 SSM 前景如何,这种坚持不懈地钻研同一个课题的精神都值得我们由衷地敬佩。

言归正传。对于事先已经对 SSM 有所了解的读者,想必知道 SSM 建模所用的是线性 ODE 系统:

83bb5c153d5bedf7bedc97038d4e174e.png

其中:

5e14534061e592e8db6aee1af1833389.png

当然我们也可以将它离散化,那么就变成一个线性 RNN 模型,这部分我们在后面的文章再展开。不管离散化与否,其关键词都是“线性”,那么马上就有一个很自然的问题:为什么是线性系统?线性系统够了吗?

我们可以从两个角度回答这个问题:线性系统既足够简单,也足够复杂。简单是指从理论上来说,线性化往往是复杂系统的一个最基本近似,所以线性系统通常都是无法绕开的一个基本点;复杂是指即便如此简单的系统,也可以拟合异常复杂的函数,为了理解这一点,我们只需要考虑一个  的简单例子:

92999733c9b3f6185bc2fddb0699f24b.png

这个例子的基本解是 。这意味着什么呢?意味着只要  足够大,该线性系统就可以通过指数函数和三角函数的组合来拟合足够复杂的函数,而我们知道拟合能力很强的傅里叶级数也只不过是三角函数的组合,如果在加上指数函数显然就更强了,因此可以想象线性系统也有足够复杂的拟合能力。

当然,这些解释某种意义上都是“马后炮”。HiPPO 给出的结果更加本质:当我们试图用正交基去逼近一个动态更新的函数时,其结果就是如上的线性系统。这意味着,HiPPO 不仅告诉我们线性系统可以逼近足够复杂的函数,还告诉我们怎么去逼近,甚至近似程度如何。‍

02 有限压缩

接下来,我们都只考虑  的特殊情形, 只不过是  时的平行推广。此时, 的输出是一个标量,进一步地,作为开头我们先假设 ,HiPPO 的目标是:用一个有限维的向量来储存这一段  的信息。

看上去这是一个不大可能的需求,因为  意味着  可能相当于无限个点组成的向量,压缩到一个有限维的向量可能严重失真。不过,如果我们对  做一些假设,并且允许一些损失,那么这个压缩是有可能做到的,并且大多数读者都已经尝试过。比如,当  在某点  阶可导的,它对应的  阶泰勒展开式往往是  的良好近似,于是我们可以只储存展开式的  个系数来作为  的近似表征,这就成功将  压缩为一个  维向量。

当然,对于实际遇到的数据来说,“阶可导”这种条件可谓极其苛刻,我们通常更愿意使用在平方可积条件下的正交函数基展开,比如傅里叶(Fourier)级数,它的系数计算公式为

69da810e60c8de760e8896c9548381f7.png

这时候取一个足够大的整数 ,只保留  的系数,那么就将  压缩为一个  维的向量了。

接下来,问题难度就要升级了。刚才我们说 ,这是一个静态的区间,而实际中  代表的是持续采集的信号,所以它是不断有新数据进入的,比如现在我们近似了  区间的数据,马上就有  的数据进来,你需要更新逼近结果来试图记忆整个  区间,接下来是 、 等等,这我们称为“在线函数逼近”。而上面的傅里叶系数公式(3),只适用于区间 ,因此需要将它进行推广。

为此,我们设 , 是  到  的一个映射,那么  作为  的函数时,它的定义区间就是 ,于是就可以复用式(3):

2459b27f2d929779a4231ac813f8592a.png

这里我们已经给系数加了标记 ,以表明此时的系数会随着  的变化而变化。

03 线性初现

能将  映射到  的函数有无穷多,而最终结果也因  而异,一些比较直观且相对简单的选择如下:

1、,即将  均匀地映射到 ;

2、注意  并不必须是满射,所以像  也是允许的,这意味着只保留了最邻近窗口  的信息,丢掉了更早的部分,更一般地有 ,其中  是一个常数,这意味着  前的信息被丢掉了;

3、也可以选择非均匀映射,比如 ,它同样是  到  的满射,但  时就映射到  了,这意味着我们虽然关注全局的历史,但同时更侧重于T时刻附近的信息。

现在我们以  为例,代入式(4)得到

1349fb1dafa2c71ff3f2640cf31f1625.png

现在我们两边求关于  的导数:

5b6591ca2511f779306ad1151e7bb74a.png

其中第二个等号我们用了分部积分公式。由于我们只保留了  的系数,所以根据傅立叶级数的公式,可以认为如下是  的一个良好近似:

ffe919426e985aa2e488fa0d06a689b4.png

那么:

419061c43049569b33685267607c6922.png

代入式(6)得:

c069f09c3db9ee83348b47dd2d12aeb5.png

将  换成 ,然后所有的  堆在一起记为 ,并且不区分  和 ,那么就可以写出

192dec13680a82bb205a7abf7516d7a7.png

这就出现了如式(1)所示的线性 ODE 系统。即当我们试图用傅里叶级数去记忆一个实时函数的最邻近窗口内的状态时,结果自然而言地导致了一个线性 ODE 系统。

04 一般框架

当然,目前只是选择了一个特殊的 ,换一个  就不一定有这么简单的结果了。此外,傅里叶级数的结论是在复数范围内的,进一步实数化也可以,但形式会变得复杂起来。所以,我们要将上一节的过程推广成一个一般化的框架,从而得到更一般、更简单的纯实数结论。

设 ,并且有目标函数  和函数基 ,我们希望有后者的线性组合来逼近前者,目标是最小化  距离:

e49ff2cdb3f2f2492c2e60a4a7748ea6.png

这里我们主要在实数范围内考虑,所以方括号直接平方就行,不用取模。更一般化的目标函数还可以再加个权重函数 ,但我们这里就不考虑了,毕竟 HiPPO 的主要结论其实也没考虑这个权重函数。

对目标函数展开,得到

a06166be687ecb8146a1abd26d5c81de.png

这里我们只考虑标准正交函数基,其定义为 , 是克罗内克 δ 函数 [6],此时上式可以简化成

63bbc4dc58f7496709435f9621a730a0.png

这只是一个关于  的二次函数,它的最小值是有解析解的:

917fbc46845873148187dc15bb1c3448.png

这也被称为  与  的内积,它是有限维向量空间的内积到函数空间的平行推广。简单起见,在不至于混淆的情况下,我们默认  就是 。

接下来的处理跟上一节是一样的,我们要对一般的  考虑  的近似,那么找一个  到  的映射 ,然后计算系数

e431ff9e7689b0efee5e4fcb694c4b84.png

同样是两边求 T 的导数,然后用分部积分法

f4eeee1476c90eabb03659c6b3da61b1.png

05 请勒让德

接下来的计算,就依赖于  和  的具体形式了。HiPPO 的全称是 High-order Polynomial Projection Operators,第一个 P 正是多项式(Polynomial)的首字母,所以 HiPPO 的关键是选取多项式为基。现在我们请出继傅里叶之后又一位大牛——勒让德(Legendre),接下来我们要选取的函数基正是以他命名的“勒让德多项式” [7]。

勒让德多项式  是关于  的  次函数,定义域为 [-1,1],满足

886246ba9525a5a38ca4be829cc6ea60.png

所以  之间只是正交,还不是标准(平分积分为 1), 才是标准正交基。

当我们对函数基  执行施密特正交化 [8] 时,其结果正是勒让德多项式。相比傅里叶基,勒让德多项式的好处是它是纯粹定义在实数空间中的,并且多项式的形式能够有助于简化部分  的推导过程,这一点我们后面就可以看到。勒让德多项式有很多不同的定义和性质,这里我们不一一展开,有兴趣的读者自行看链接中维基百科介绍即可。

接下来我们用到两个递归公式来推导一个恒等式,这两个递归公式是

0828be420418a1d54db8743e2288d7ba.png

由第一个公式(17)迭代得到:

be676fdcb153bccb352c383489dd0f68.png

其中当  是偶数时  否则 。代入第二个公式(17)得到

02d61827f37a6dff45d26833c8287920.png

继而有

66f5afab129460cb8c2e157f43643fa1.png

这些就是等会要用到的恒等式。此外,勒让德多项式满足 ,这个边界值后面也会用到。

正如  维空间中不止有一组正交基也一样,正交多项式也不止有勒让德多项式一种,比如还有切比雪夫(Chebyshev)多项式 [9],如果算上加权的目标函数(即 ),还有拉盖尔多项式 [10] 等,这些在原论文中都有提及,但 HiPPO 的主要结论还是基于勒让德多项式展开的,所以剩余部分这里也不展开讨论了。

06 邻近窗口

完成准备工作后,我们就可以代入具体的  进行计算了,计算过程跟傅里叶级数的例子大同小异,只不过基函数换成了勒让德多项式构造的标准正交基 。作为第一个例子,我们同样先考虑只保留最邻近窗口的信息,此时  将  映射到 ,原论文将这种情形称为“LegT(Translated Legendre)”。

直接代入式(15),马上得到

64ebc9331c95cac69b8f56d8566bc0b4.png

我们首先处理  项,跟傅里叶级数那里同样的思路,我们截断  作为  的一个近似:

5ac984a9d3a6d5ec73c7329bcf8e8fc2.png

从而有

 。

接着,利用式(19)得到

2a082fde1785001f632238aef5cb14f1.png

将这些结果整合起来,就有

a704aa695c6f274a45d9b269bf39a3b0.png

再次地,将  换回 ,并将所有的  堆在一起记为 ,那么根据上式可以写出

f3f7f04006f097384c8a1c541a87f14b.png

我们还可以给每个  都引入一个缩放因子,来使得上述结果更一般化。比如我们设 ,代入式(25)整理得

a7aaa741f863de3ae32a112070f1fb58.png

如果取 ,那么  不变,,这就对齐了原论文的结果,如果取 ,那么就得到了 Legendre Memory Units [11] 中的结果

7266c01cbeaa9c36b5a170392fb6d3fc.png

这些形式在理论上都是等价的,但可能存在不同的数值稳定性。比如一般来说当  的性态不是特别糟糕时,我们可以预期  越大, 的值就相对越小,这样直接用  的话  向量的每个分量的尺度就不大对等,这样的系统在实际计算时容易出现数值稳定问题,而取  改用  的话意味着数值小的分量会被适当放大,可能有助于缓解多尺度问题从而使得数值计算更稳定。

07 整个区间

现在我们继续计算另一个例子:,它将  均匀映射到 ,这意味着我们没有舍弃任何历史信息,并且平等地对待所有历史,原论文将这种情形称为 “LegS(Scaled Legendre)”。

同样地,通过代入式(15)得到

e78b862b179ceccfb7e1eaaae8e1e961.png

利用公式(21)得到

eed837a7c496e1efdb55af9ab51a8e49.png

于是有

2cbe16e43167824ddae4abaf917e3faf.png

将  换回 ,将所有的  堆在一起记为 ,那么根据上式可以写出

6a96b526c513e6e04170ebd57f1ce6fc.png

引入缩放因子来一般化结果也是可行的:设 ,代入式(25)整理得

4e6c05299f8f6d7e0d03974fea01557a.png

取  就可以让  不变, 变为 ,就对齐了原论文的结果。如果取 ,就可以像上一节 LegT 的结果一样去掉根号

e9a4432597e9b2d6438aab9639328485.png

但原论文没有考虑这种情况,原因不详。

08 延伸思考

回顾 Leg-S 的整个推导,我们可以发现其中关键一步是将  拆成  的线性组合,对于正交多项式来说, 是一个  次多项式,所以这种拆分必然可以精确成立,但如果是傅立叶级数的情况, 是指数函数,此时类似的拆分做不到了,至少不能精确地做到,所以可以说选取正交多项式为基的根本目的是简化后面推导。

特别要指出的是,HiPPO 是一个自下而上的框架,它并没有一开始就假设系统必须是线性的,而是从正交基逼近的角度反过来推出其系数的动力学满足一个线性 ODE 系统,这样一来我们就可以确信,只要认可所做的假设,那么线性 ODE 系统的能力就是足够的,而不用去担心线性系统的能力限制了你的发挥。

当然,HiPPO 对于每一个解所做的假设及其物理含义也很清晰,所以对于重用了 HiPPO 矩阵的 SSM,它怎么储存历史、能储存多少历史,从背后的 HiPPO 假设就一清二楚。

比如 LegT 就是只保留  大小的最邻近窗口信息,如果你用了 LegT 的 HiPPO 矩阵,那么就类似于一个 Sliding Window Attention;而 LegS 理论上可以捕捉全部历史,但这有个分辨率问题,因为  的维度代表了拟合的阶数,它是一个固定值,用同阶的函数基去拟合另一个函数,肯定是区间越小越准确,区间越大误差也越大,这就好比为了一次性看完一幅大图,那么我们必须站得更远,从而看到的细节越少。

诸如 RWKV、LRU 等模型,并没有重用 HiPPO 矩阵,而是改为可训练的矩阵,原则上具有更多的可能性来突破瓶颈,但从前面的分析大致上可以感知到,不同矩阵的线性 ODE 只是函数基不同,但本质上可能都只是有限阶函数基逼近的系数动力学。既然如此,分辨率与记忆长度就依然不可兼得,想要记忆更长的输入并且保持效果不变,那就只能增加整个模型的体量(即相当于增加 hidden_size),这大概是所有线性系统的特性。

09 文章小结

本文以尽可能简单的方式重复了《HiPPO: Recurrent Memory with Optimal Polynomial Projections》[4](简称 HiPPO)的主要推导。HiPPO 通过适当的记忆假设,自下而上地导出了线性 ODE 系统,并且针对勒让德多项式的情形求出了相应的解析解(HiPPO 矩阵),其结果被后来诸多 SSM(State Space Model)使用,可谓是 SSM 的重要奠基之作。

参考文献

[1] https://papers.cool/arxiv/2312.00752

[2] https://papers.cool/arxiv/2305.13048

[3] https://papers.cool/arxiv/2307.08621

[4] https://papers.cool/arxiv/2008.07669

[5] https://dblp.org/pid/130/0612.html

[6] https://en.wikipedia.org/wiki/Kronecker_delta

[7] https://en.wikipedia.org/wiki/Legendre_polynomials

[8] https://en.wikipedia.org/wiki/Gram–Schmidt_process

[9] https://en.wikipedia.org/wiki/Chebyshev_polynomials

[10] https://en.wikipedia.org/wiki/Laguerre_polynomials

[11] https://proceedings.neurips.cc/paper/2019/file/952285b9b7e7a1be5aa7849f32ffff05-Paper.pdf

编辑:王菁

校对:林亦霖80ccd061f0564aaf15b117bb3d2cc613.png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值