重温被Mamba带火的SSM:HiPPO的一些遗留问题

6d6575874b90da7adc085dae8aa1a24a.gif

©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 科学空间

研究方向 | NLP、神经网络

书接上文,在上一篇文章《重温被Mamba带火的SSM:线性系统和HiPPO矩阵》中,我们详细讨论了 HiPPO 逼近框架其 HiPPO 矩阵的推导,其原理是通过正交函数基来动态地逼近一个实时更新的函数,其投影系数的动力学正好是一个线性系统,而如果以正交多项式为基,那么线性系统的核心矩阵我们可以解析地求解出来,该矩阵就称为 HiPPO 矩阵。

当然,上一篇文章侧重于 HiPPO 矩阵的推导,并没有对它的性质做进一步分析,此外诸如“如何离散化以应用于实际数据”、“除了多项式基外其他基是否也可以解析求解”等问题也没有详细讨论到。接下来我们将补充探讨相关问题。

72695fe7b48accf34f164ba0d20380f8.png

离散格式

假设读者已经阅读并理解上一篇文章的内容,那么这里我们就不再进行过多的铺垫。在上一篇文章中,我们推导出了两类线性 ODE 系统,分别是:

bcfb9b29d3391efee7e4a707cde8db76.png

其中 是与时间 无关的常数矩阵,HiPPO 矩阵主要指矩阵 。在这一节中,我们讨论这两个 ODE 的离散化。

a5b1acf24c70c7accdadd17a6a6e4fc4.png

输入转换

在实际场景中,输入的数据点是离散的序列 ,比如流式输入的音频信号、文本向量等,我们希望用如上的 ODE 系统来实时记忆这些离散点。为此,我们先定义

497b0bfbbc4c9982f01dd179bdbf6e5b.png

其中 就是离散化的步长。该定义也就是说在区间 内, 是一个常数函数,其值等于 。很明显这样定义出来的 无损原本 序列的信息,因此记忆 就相当于记忆 序列。

从 变换到 ,可以使得输入信号重新变回连续区间上的函数,方便后面进行积分等运算,此外在离散化的区间内保持为常数,也能够简化离散化后的格式。

205c65fc8cd01f2985d41286cfe0eca6.png

LegT版本

我们先以 LegT 型 ODE(1)为例,将它两端积分

39b14cb7a4c9c6c6ef8756fb97a7403a.png

其中 。根据 的定义,它在 区间内恒为 ,于是 的积分可以直接算出来:

e83b1cc1b83453f26dfa8e386f2df729.png

接下来的结果,就取决于我们如何近似 的积分了。假如我们认为在 区间内 近似恒等于 ,那么就得到前向欧拉格式

de2f8a3918a9115e16e85276fe5e6bd2.png

我们认为在 区间内 近似恒等于 ,那么就得到后向欧拉格式

18ff7d7ac945c7838410ee8a88f6e0e1.png

前后向欧拉都具有相同的理论精度,但后向通常会有更好的数值稳定性。如果要更准确一些,那么认为在 区间内 近似恒等于 ,那么得到双线性形式:

cc10a25ba01866cbdc4e9df66d6c620c.png

这也等价于先用前向欧拉走半步,再用后向欧拉走半步。更一般地,我们还可以认为在 区间内 近似恒等于 ,其中 ,这就不进一步展开了。事实上,我们也可以完全不做近似,因为结合式(1)以及在区间 中 是常数 ,我们完全可以用“常数变易法” [1] 来精确求解出来,结果是

20d3258b86483370cecf57e2e4a5274e.png

这里的矩阵指数按照级数来定义,可以参考《恒等式 det(exp(A)) = exp(Tr(A)) 赏析》[2]。

b5c3af781f5ef5aed39184f38ca012e9.png

LegS版本

现在轮到 LegS 型 ODE 了,它的思路跟 LegT 型基本一致,结果也大同小异。首先将式(2)两端积分得到

f56a5b582595da8803244bd1d4798e41.png

根据 定义,第二项积分的 在 恒为 ,所以它相当于 的积分,可以直接积分出来得 ,当然直接换为一阶近似 也无妨,因为本身 到 的变换有很大自由度,这点误差无所谓。至于第一项积分,我们直接采用精度更高的中点近似,得到

3d32358703cd100ba8d7888f7c7ca7ea.png

事实上,式(2)也可以精确求解,只需要留意到它等价于

5f2486c7b90f179dd7644aa7801ed435.png

这意味着只需要做变量代换 ,那么 LegS 型 ODE 就可以转化为 LegT 型 ODE:

b7453768ea99eb14bc3107a0f7b24d9c.png

利用式(9)得到(由于变量代换,时间间隔由 变成 )

3b27bf53d44cd44da1816ca3898cdfb2.png

然而,上式虽然是精确解,但不如同为精确解的式(9)好用,因为式(9)的指数矩阵部分是 ,跟时间 无关,所以一次性计算完就可以了。但上式中 在矩阵指数里边,意味着在迭代过程中需要反复计算矩阵指数,对计算并不友好,所以 LegS 型 ODE 我们一般只会用式(11)来离散化。

dd915812ddf968c8f6910152b6a5ea0e.png

优良性质

接下来,LegS 是我们的重点关注对象。重点关注 LegS 的原因并不难猜,因为从推导的假设来看,它是目前求解出来的唯一一个能够记忆整个历史的 ODE 系统,这对于很多场景如多轮对话来说至关重要。此外,它还有其他的一些比较良好且实用的性质。

66cb881449e2afc52397ee3bee3ac21c.png

尺度不变

比如,LegS 的离散化格式(11)是步长无关的,我们只需要将 代入里边,并记 ,就可以发现

d07afc73debd7b6753b67d553b9cd06c.png

步长 被自动地消去了,从而自然地减少了一个需要调的超参数,这对于炼丹人士显然是一个好消息。注意步长无关是 LegS 型 ODE 的一个固有性质,它跟具体的离散化方式并无直接关系,比如精确解(14)同样是步长无关的:

33176b7526f9e7e8e7c14d6c880ed106.png

其背后的原因,在于 LegS 型 ODE 满足“时间尺度不变性(Timescale equivariance)”——如果我们设 代入 LegS 型 ODE,将得到

09d50c22725e81159f251f43fb9a3447.png

这意味着,当我们将 换成 时,LegS 的 ODE 形式并没有变化,而对应的解则是 换成了 。这个性质的直接后果就是:当我们选择更大的步长时,递归格式不需要发生变化,因为结果 的步长也会自动放大,这就是 LegS 型 ODE 离散化与步长无关的本质原因。

06cec89424c67c4b5b9ea084f59c91fa.png

长尾衰减

LegS 型 ODE 的另一个优良性质是,它关于历史信号的记忆是多项式衰减(Polynomial decay)的,这比常规 RNN 的指数衰减更缓慢,从而理论上能记忆更长的历史,更不容易梯度消失。为了理解这一点,我们可以从精确解(16)出发,从式(16)可以看到,每递归一步,历史信息的衰减效应可以用矩阵指数 来描述,那么从第 步递归到第 步,总的衰减效应是

cf3627dbfd1464f267373f580baee770.png

回顾 HiPPO-LegS 中 的形式:

0231b5a5428604d9ff5ad1bfadbe5e5f.png

从定义可以看出, 是一个下三角阵,其对角线元素为 。我们知道,三角阵的对角线元素正好是它的特征值(参考 Triangular matrix [3]),由此可以看到一个 大小的 矩阵,有 个不同的特征值 ,这说明 矩阵是可对角化的,即存在可逆矩阵 ,使得 ,其中 ,于是我们有

6479e85bbf0def16844ccb125f1342de.png

可见,最终的衰减函数是 的 次函数的线性组合,所以 LegS 型 ODE 关于历史记忆至多是多项式衰减的,比指数衰减更加长尾,因此理论上有更好的记忆力。

f97458d8c84c36844cff119fda36ceee.png

计算高效

最后,我们指出 HiPPO-LegS 的 矩阵是计算高效(Computational efficiency)的。具体来说,直接按照矩阵乘法的朴素实现的话,一个 的矩阵乘以 的列向量,需要做 次乘法,但 LegS 的 矩阵与向量相乘则可以降低到 次,更进一步地,我们还可以证明离散化后的(11)也可以在 完成。

为了理解这一点,我们首先将 HiPPO-LegS 的 A 矩阵等价地改写成

871541e5b9ae62135b679f34a0d5e682.png

对于向量 ,我们有

51134aef6c120c7cf0303719454ad47e.png

这包含三种运算,第一项的 是向量 与 做逐位相乘运算,第二项的 则是向量 与 做逐位相乘,然后 就是 运算,最后乘以 就是再逐位相乘向量 ,每一步都可以在 内完成,因此总的复杂度是 的。

我们再来看(11),它包含两步“矩阵-向量”乘法运算,一是 , 是任意实数,刚才我们已经证明了 是计算高效的,自然 也是;二是 ,接下来我们将证明它也是计算高效的。这只需要留意到求 等价于解方程 ,利用上面给出的 表达式,我们可以得到

833c8512f17720e045f6c5f91d7cabc8.png

记 ,那么 ,代入上式得

a1ff87a0e93eae30d30163d67fa19bc9.png

整理得

c537920d06a538faa67995fcc0bf5951.png

这是一个标量的递归式,可以完全串行地计算,也可以利用 Prefix Sum 的相关算法并行计算(参考这里),计算复杂度为 或者 ,总之相比 都会更加高效。

4613540ea50271eb6abbcea31d6646ca.png

傅立叶基

最后,我们以傅立叶基的一个推导收尾。在上一篇文章中,我们以傅立叶级数来引出了线性系统,但只推导了邻近窗口形式的结果,而后面的勒让德多项式基我们则推导了邻近窗口和完整区间两个版本(即 LegT 和 LegS)。那么傅立叶基究竟能不能推导一个跟 LegS 相当的版本呢?其中会面临什么困难呢?下面我们对此进行探讨。

同样地,相关铺垫我们不再重复,按照上一节的记号,傅立叶基的系数为

d6cb1622f8dc96ed9793703e848eed00.png

跟 LegS 一样,为了记忆整个 区间的信号,我们需要一个 的映射,为此选取最简单的 ,代入后两边求导得到

7e28fd098d2f6853e2621766f82cb8fd.png

分部积分得到

4eb7f20411cd0522433f09ea297a1cb8.png

上一篇文章我们提到,HiPPO 选取勒让德多项式为基的重要原因之一是 可以分解为 的线性组合,而傅里叶基的 则不能做到这一点。但事实上,如果允许误差的话,这个断论是不成立的,因为我们同样可以将 分解为傅里叶级数:

2c288ee497fc14a0bfb889cd42028a41.png

这里的求和有无限项,如果要截断为有限项的话,就会产生误差,但我们可以先不纠结这一点,直接往上代入得到

0ca892c1072cfc2bdbab5e54112d5d12.png

这样一来

cdae54c93e963feeb535458638bcf780.png

所以可以写出

70c5b68a78b4d80deff97730fd433097.png

实际使用的时候,我们只需要截断 ,就可以得到一个 的矩阵。截断带来的误差其实是无所谓的,因为我们在推导 HiPPO-LegT 的时候同样引入了有限级数近似,那会我们同样也没考虑误差,或者反过来讲,对于特定的任务,我们会选择适当的规模(即 N 的大小),而这个“适当”的含义之一,就是截断带来误差对于该任务是可以忽略的。

对大多数人来说,傅立叶基的这个推导可能还更容易理解一些,因为勒让德多项式对很多读者来说都比较陌生,尤其是 LegT、LegS 推导过程中用到的几个恒等式,而对于傅立叶级数大多数读者应该或多或少都有所了解。

不过,从结果上来看,傅立叶基的这个结果可能不如 LegS 实用,一来它引入了复数,这增加了实现的复杂度,二来它推导出的 A 矩阵不像 LegS 那样是个相对较淡的下三角阵,因此理论分析起来也更为复杂。所以,大家权当它是一道深化对 HiPPO 的理解的练习题就好。

c446b4a23e472b16abbfc663f9a16acc.png

文章小结

在这篇文章中,我们补充探讨了上一篇文章介绍的 HiPPO 的一些遗留问题,其中包括如何对 ODE 进行离散化、LegS 型 ODE 的一些优良性质,以及利用傅立叶基记忆整个历史区间的结果推导(即 LegS 的傅立叶版本),以求获得对 HiPPO 的更全面理解。

outside_default.png

参考文献

outside_default.png

[1] https://en.wikipedia.org/wiki/Variation_ou_parameters

[2] https://kexue.fm/archives/6377#矩阵指数

[3] https://en.wikipedia.org/wiki/Triangular_matrix

更多阅读

2420ac57a87a14b1b1f88b3577eca5c7.png

ce5a24e5e570a2213574cdeebfcecd7d.png

30283b0e7f0198d912fe408138fbefdb.png

a94ba6b0fd7a7370cd3e74ac195b6d3c.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

44b30d72a0f89083c53e5078aecd1880.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

·

11fa5b4cf9f7140b42598f8fa482fe3b.jpeg

  • 27
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值