Mamba方法精读

背景:Mamba的研究目标

1. 设计一个新的结构,性能不弱于Transformer,并具备更高的计算效率

大型语言模型(Large Language Models,LLMs)的成功离不开Transformer结构,然而,Transformer的推理效率较低。以文本生成为例,文本序列的生成是逐步进行的,Transformer每生成一个token,就需要和已有的tokens计算一次注意力。注意力的时间复杂度是 O ( n 2 ) O(n^2) O(n2),这意味着随着序列长度的增加,Transformer的计算量将呈平方级递增,导致难以处理很长的序列。

不过,Transformer的训练过程却很高效。与RNN不同,Transformer在处理输入序列时并不依赖于前一个时间步的输出。这意味着在训练过程中,整个序列的所有位置可以同时输入模型进行处理,无需按时间步顺序逐步计算。这使得Transformer可以在一个批次中并行处理整个输入序列。

Mamba的研究目标即保留Transformer的训练效率和生成精度的同时,提高模型的推理效率。具体来说,训练过程中,Mamba的时间和空间复杂度都是 O ( n ) O(n) O(n);而推理时可以采用自回归的方式,使得每一步的生成只需要常量的时间复杂度。

研究思路

1. 概述

作者发现状态空间模型(State Space Model,SSM)具有良好的特性,使其可以进行高效的训练和推理。然而,现有SSM性能还不够,因此设计了Selective State Space Model;并将其和MLP整合成一个block,多个blocks构成了Mamba结构。

2. 什么是SSM?

备注: 论文1中的SSM实际代指S4(Structured State Space Sequence models),本文也使用SSM代指S4。

结论: SSM本质上是RNN,区别是当有序列输入(多个时间步的输入)时,它能够以卷积的形式快速生成序列输出(每个时间步对应的输出),而非RNN中串行的方式。

简单回顾RNN: 下图展示了RNN的工作流程。在每个时间步t,网络接收当前输入 x t x_t xt和上一步的状态 h t − 1 h_{t-1} ht1,得到当前步的输出 y t y_t yt
在这里插入图片描述
当生成当前时间t的输出时,RNN只用考虑当前输入 x t x_t xt和上一步的状态 h t − 1 h_{t-1} ht1,不需要和Transformer一样重新计算之前的状态(即注意力),因此可以高效推理,并且理论上拥有无限长的序列窗口。

然而,RNN在训练时只能串行处理,降低了训练效率。此外,在实际应用中,RNN往往倾向于记住和利用短期依赖信息,容易遗忘长程信息。

连续状态空间模型: SSM启发于连续状态空间模型,该模型被用于描述动态系统,被广泛应用于控制系统和信号处理等领域。和RNN相似,它也通过一个潜在状态 h ( t ) ∈ R N h(t)\in\mathbb{R}^N h(t)RN将输入序列 x ( t ) ∈ R x(t)\in\mathbb{R} x(t)R映射为输出序列 y ( t ) ∈ R y(t)\in\mathbb{R} y(t)R。连续状态空间模型的公式为:
h ′ ( t ) = A h ( t ) + B x ( t ) y ( t ) = C h ( t ) \begin{align} h'(t)&=\mathbf{A}h(t)+\mathbf{B}x(t) \tag{1a} \\ y(t)&=\mathbf{C}h(t) \tag{1b} \end{align} h(t)y(t)=Ah(t)+Bx(t)=Ch(t)(1a)(1b)
其中公式(1a)是状态转移方程,公式(1b)是输出方程; A \mathbf{A} A B \mathbf{B} B C \mathbf{C} C是参数矩阵。由于参数矩阵的值不随时间步变化,且无非线性操作,因此上述模型是线性时不变的(linear time invariance,LTI)。

离散化: 上述公式无法直接应用到离散的时间步上。为了获得 h t h_t ht h t − 1 h_{t-1} ht1之间的关系,作者采用零阶保持法(Zero-Order Hold,ZOH)离散化上述公式,得到以下结果:
h t = A ˉ h t − 1 + B ˉ x t y t = C h t A ˉ = e x p ( Δ A ) B ˉ = Δ A − 1 ( e x p ( Δ A ) − I ) ⋅ Δ B \begin{align} h_t&=\mathbf{\bar{A}}h_{t-1}+\mathbf{\bar{B}}x_t \tag{2a} \\ y_t&=\mathbf{C}h_t \tag{2b} \\ \mathbf{\bar{A}}&=exp(\mathbf{\Delta A}) \quad \mathbf{\bar{B}}=\mathbf{\Delta A}^{-1}(exp(\mathbf{\Delta A})-\mathbf{I}) \cdot \mathbf{\Delta B} \tag{3} \end{align} htytAˉ=Aˉht1+Bˉxt=Cht=exp(ΔA)Bˉ=ΔA1(exp(ΔA)I)ΔB(2a)(2b)(3)
其中 Δ \mathbf{\Delta} Δ也是可学习参数。这部分不理解不要紧

通过卷积快速得到每一步的输出: 因为SSM是LTI,所以可以得到每一步 y t y_t yt的closed-form计算公式。举个例子,假设我想要计算 y 3 y_3 y3
y 3 = ( A ˉ h 2 + B ˉ x 3 ) ⋅ C y 3 = ( A ˉ ( A ˉ h 1 + B ˉ x 2 ) + B ˉ x 3 ) ⋅ C y 3 = C B ˉ x 3 + C B ˉ A ˉ x 2 + C B ˉ A ˉ 2 x 1 y k = ∑ i = 1 k C B ˉ A ˉ k − i x i \begin{align} y_3&=(\mathbf{\bar{A}}h_{2}+\mathbf{\bar{B}}x_3) \cdot \mathbf{C} \tag{4} \\ y_3&=(\mathbf{\bar{A}}(\mathbf{\bar{A}}h_{1}+\mathbf{\bar{B}}x_2)+\mathbf{\bar{B}}x_3) \cdot \mathbf{C} \tag{5} \\ y_3&=\mathbf{C}\mathbf{\bar{B}}x_3+\mathbf{C}\mathbf{\bar{B}}\mathbf{\bar{A}}x_2+\mathbf{C}\mathbf{\bar{B}}\mathbf{\bar{A}}^2x_1 \tag{6} \\ y_k&=\sum_{i=1}^{k} \mathbf{C}\mathbf{\bar{B}}\mathbf{\bar{A}}^{k-i}x_i \tag{7} \end{align} y3y3y3yk=(Aˉh2+Bˉx3)C=(Aˉ(Aˉh1+Bˉx2)+Bˉx3)C=CBˉx3+CBˉAˉx2+CBˉAˉ2x1=i=1kCBˉAˉkixi(4)(5)(6)(7)
可以发现,公式(7)和卷积的形式相同。假设输出序列 y = [ y k , . . . , y 1 ] y=[y_k,...,y_1] y=[yk,...,y1],输入序列 x = [ x k , . . . , x 1 ] x=[x_k,...,x_1] x=[xk,...,x1],公式(7)又可写为:
K ˉ = ( C B ˉ , . . . , C B ˉ A ˉ k − 1 ) y = x ∗ K ˉ . (8) \begin{aligned} \mathbf{\bar{K}}&=(\mathbf{C}\mathbf{\bar{B}},...,\mathbf{C}\mathbf{\bar{B}}\mathbf{\bar{A}}^{k-1}) \\ y&=x \ast \mathbf{\bar{K}}. \tag{8} \end{aligned} Kˉy=(CBˉ,...,CBˉAˉk1)=xKˉ.(8)
因此每个forward就是一次卷积操作。

3. Mamba中的Selective SSM对SSM做了什么改进?

B , C → B t , C t B, C \rightarrow B_t, C_t B,CBt,Ct: SSM是时不变模型,即模型参数 A , B , C A, B, C A,B,C在每个时间步相同,与输入无关,因此难以进行content-aware reasoning。在论文1中,作者列举了两个的任务:Copying & Selective Copying来证明content-aware learning的重要性,如下图所示:
在这里插入图片描述
时不变模型可以完成Copying任务,但无法完成Selective Copying,因为它无法区分当前步是有效还是无效输入。因此,作者认为模型参数应该是input-dependent,即时变的。这也是Selective SSM的改进动机。算法示意图如下所示:
在这里插入图片描述

此外,作者认为权衡序列模型的效率和精度的关键在于how well they compress their state:

In summary, the efficiency vs. effectiveness tradeoff of sequence models is characterized by how well they compress their state: efficient models must have a small state, while effective models must have a state that contains all necessary information from the context. In turn, we propose that a fundamental principle for building sequence models is selectivity: or the context-aware ability to focus on or filter out inputs into a sequential state.

例如,attention is both effective and inefficient because it explicitly does not compress context at all.

4. 如何高效实现Selective SSM?

将SSM变成时变模型后,将遇到两个问题:

  1. 无法直接利用卷积,不能享受卷积带来的高效性。因为此时“卷积核”是时变的;
  2. 参数 B , C B,C B,C为了变成 B t , C t B_t,C_t Bt,Ct,会多一个L维度(时间/序列维度),参数量的变大增加了对memory IOs带宽和显存的占用,增加了计算和显存成本。

对此,作者设计了三个改进:

  1. 相比在GPU HBM(high-bandwidth memory)中准备形状为 ( B , L , D , N ) (B, L, D, N) (B,L,D,N) ( A ˉ , B ˉ ) (\mathbf{\bar{A}}, \mathbf{\bar{B}}) (Aˉ,Bˉ),先将SSM参数 ( Δ , A , B , C ) (\mathbf{\Delta}, \mathbf{A}, \mathbf{B}, \mathbf{C}) (Δ,A,B,C)加载到更快的SRAM,在SRAM中执行discretization和recurrence,最后将结果 ( B , L , D ) (B, L, D) (B,L,D)写回HBM。这避免了在HBM和SRAM之间的带宽消耗,因为搬运 ( B , L , D , N ) (B, L, D, N) (B,L,D,N)d的参数显然比 ( B , L , D ) (B, L, D) (B,L,D)的更占用带宽。
  2. 为了避免串行recurrence减慢训练速度,采用work-efficient parallel scan algorithm并行化。
  3. 为了减少显存占用,前向传播时不保存intermediate states,而在反向传播时重新计算。

参考资料


  1. Mamba: Linear-Time Sequence Modeling with Selective State Spaces ↩︎ ↩︎

  • 11
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值