背景:Mamba的研究目标
1. 设计一个新的结构,性能不弱于Transformer,并具备更高的计算效率
大型语言模型(Large Language Models,LLMs)的成功离不开Transformer结构,然而,Transformer的推理效率较低。以文本生成为例,文本序列的生成是逐步进行的,Transformer每生成一个token,就需要和已有的tokens计算一次注意力。注意力的时间复杂度是 O ( n 2 ) O(n^2) O(n2),这意味着随着序列长度的增加,Transformer的计算量将呈平方级递增,导致难以处理很长的序列。
不过,Transformer的训练过程却很高效。与RNN不同,Transformer在处理输入序列时并不依赖于前一个时间步的输出。这意味着在训练过程中,整个序列的所有位置可以同时输入模型进行处理,无需按时间步顺序逐步计算。这使得Transformer可以在一个批次中并行处理整个输入序列。
Mamba的研究目标即保留Transformer的训练效率和生成精度的同时,提高模型的推理效率。具体来说,训练过程中,Mamba的时间和空间复杂度都是 O ( n ) O(n) O(n);而推理时可以采用自回归的方式,使得每一步的生成只需要常量的时间复杂度。
研究思路
1. 概述
作者发现状态空间模型(State Space Model,SSM)具有良好的特性,使其可以进行高效的训练和推理。然而,现有SSM性能还不够,因此设计了Selective State Space Model;并将其和MLP整合成一个block,多个blocks构成了Mamba结构。
2. 什么是SSM?
备注: 论文1中的SSM实际代指S4(Structured State Space Sequence models),本文也使用SSM代指S4。
结论: SSM本质上是RNN,区别是当有序列输入(多个时间步的输入)时,它能够以卷积的形式快速生成序列输出(每个时间步对应的输出),而非RNN中串行的方式。
简单回顾RNN: 下图展示了RNN的工作流程。在每个时间步t,网络接收当前输入 x t x_t xt和上一步的状态 h t − 1 h_{t-1} ht−1,得到当前步的输出 y t y_t yt。
当生成当前时间t的输出时,RNN只用考虑当前输入 x t x_t xt和上一步的状态 h t − 1 h_{t-1} ht−1,不需要和Transformer一样重新计算之前的状态(即注意力),因此可以高效推理,并且理论上拥有无限长的序列窗口。
然而,RNN在训练时只能串行处理,降低了训练效率。此外,在实际应用中,RNN往往倾向于记住和利用短期依赖信息,容易遗忘长程信息。
连续状态空间模型: SSM启发于连续状态空间模型,该模型被用于描述动态系统,被广泛应用于控制系统和信号处理等领域。和RNN相似,它也通过一个潜在状态 h ( t ) ∈ R N h(t)\in\mathbb{R}^N h(t)∈RN将输入序列 x ( t ) ∈ R x(t)\in\mathbb{R} x(t)∈R映射为输出序列 y ( t ) ∈ R y(t)\in\mathbb{R} y(t)∈R。连续状态空间模型的公式为:
h ′ ( t ) = A h ( t ) + B x ( t ) y ( t ) = C h ( t ) \begin{align} h'(t)&=\mathbf{A}h(t)+\mathbf{B}x(t) \tag{1a} \\ y(t)&=\mathbf{C}h(t) \tag{1b} \end{align} h′(t)y(t)=Ah(t)+Bx(t)=Ch(t)(1a)(1b)
其中公式(1a)是状态转移方程,公式(1b)是输出方程; A \mathbf{A} A, B \mathbf{B} B和 C