原文链接:https://arxiv.org/abs/2312.00752
1. 引言
基石模型(FM)的主干网络通常是序列模型,处理任意的输入序列。但现代FM主要基于Transformer这一序列模型,及其核心的注意力。但是,自注意力仅能在上下文窗口中密集地传递信息,而无法建模窗口外部的数据;此外,其尺度与窗口长度成二次方关系。注意力相关高效的改进牺牲了其有效性,因此也未被有效地用于不同领域。
最近,结构状态空间序列模型(SSM)作为序列建模的有前景方法,可被理解为RNN与CNN的结合,并受经典状态空间模型的启发。这类模型能高效计算,且尺度与序列长度成比例关系。此外,在部分模态下,还可建模长距离依赖关系,且在连续信号(如音频与视觉)下取得了成功。但对于离散且信息密集的数据(如文本)则不那么有效。
本文提出选择性状态空间模型,在前面的工作上做出改进,达到Transformer的建模能力,且尺度随序列长度线性增大。
选择机制:过去的方法缺乏以数据依赖的方式高效选择数据的能力(关注或忽视特定输入)。本文通过将SSM的参数基于输入参数化,设计选择机制,使模型过滤无关信息并记忆相关信息。
硬件感知的算法:所有之前的SSM需要是时不变和输入不变的,以高效计算。本文使用硬件感知的算法来克服这一问题,递归地使用扫描而非卷积计算模型,且不实现扩展状态以避免在GPU内存层次结构的不同层进行IO访问。这样,实施速度在理论上和现代硬件上均能超过过去的方法(伪线性时间)。
结构:本文将之前的SSM结构与Transformer的MLP组合为块(Manba),一种包含了选择性状态空间的简单而同质的结构设计。
选择性SSM与其扩展Manba均为完全递归的模型,适合作为以序列为输入的通用基石模型的主干网络。其关键属性为:
- 高质量:选择性能为密集模态(语言、基因组)带来强性能;
- 快速训练和推断:训练时的计算与存储尺度均随序列长度线性变化,推断时自回归地展开模型使得每步只需常数时间,因为无需过去元素的缓存。
- 长上下文:质量与效率使其能在1M长度序列上产生性能提升。
在语言、音频、基因组等领域上的实验表明,Mamba只需更少的参数量就能达到Transformer相同的性能,且速度更快。
2. 状态空间模型
结构状态空间模型(S4)与RNN、CNN以及经典状态空间模型相关。其受到特定连续系统的启发,该连续系统通过隐状态 h ( t ) ∈ R N h(t)\in\mathbb R^N h(t)∈RN映射1维函数或序列 x ( t ) ∈ R → y ( t ) ∈ R x(t)\in\mathbb R\rightarrow y(t)\in\mathbb R x(t)∈R→y(t)∈R。
S4模型由4个参数定义( Δ , A , B , C \Delta,A,B,C Δ,A,B,C),包含序列到序列的两阶段变换(式(1)):
h ′ ( t ) = A h ( t ) + B x ( t ) ( 1 a ) h t = A ˉ h t − 1 + B ˉ x t ( 2 a ) K ˉ = ( C B ˉ , C A ˉ B ˉ , ⋯ , C A ˉ k B ˉ , ⋯ ) ( 3 a ) y ( t ) = C h ( t ) ( 1 b ) y t = C h t ( 2 b ) y = x ∗ K ˉ ( 3 b ) \begin{matrix} h'(t)=Ah(t)+Bx(t)&(1a)&h_t=\bar Ah_{t-1}+\bar Bx_t&(2a)&\bar K=(C\bar B,C\bar A\bar B,\cdots,C\bar A^k\bar B,\cdots)&(3a)\\ y(t)=Ch(t)&(1b)&y_t=Ch_t&(2b)&y=x*\bar K&(3b) \end{matrix} h′(t)=Ah(t)+Bx(t)y(t)=Ch(t)(1a)(1b)ht=Aˉht−1+Bˉxtyt=Cht(2a)(2b)