从控制系统到语言模型 | Mamba 的前世今生

作者 | 紫气东来  编辑 | 汽车人

原文链接:https://zhuanlan.zhihu.com/p/680833040

点击下方卡片,关注“自动驾驶之心”公众号

戳我-> 领取自动驾驶近15个方向学习路线

>>点击进入→自动驾驶之心Mamba技术交流群

本文只做学术分享,如有侵权,联系删文

在当前 Transformer 之外的模型,势头最盛的恐怕就是 Mamba 了,而要学习和研究 Mamba 就离不开其前导模型—— 状态空间模型(State Space Models, SSM) ,而 SSM 则又是脱胎于动态控制系统。本文将从最初的起点出发,通过数学和实现细节上的分析和讨论,抽丝剥茧,以期对 Mamba 有一个根本而深刻的认识。

该篇在内容上属于前一篇的一部分。

由于前文篇幅所限,且该篇内容自成体系,还由于该篇内容的底层原理正是来源于笔者的大学专业内容,仅以此文来纪念青葱岁月与奋斗年华。

一、控制系统与状态空间模型

考虑控制系统的连续状态的时不变系统,如下:

f5a730df79ca7b0e497446c1c12dbe5a.png
连续状态的时不变系统

改图的状态方程可以表示为:

1a7412b4c2e0e8a946c668cc7446110c.png e1ca9cb482c187692b600c6b73d46b9e.png 982484d70efc60e644aa4aa82b6afc7b.png

其中D被称为 skip connection, 可以理解为经过变换的残差连接。因此通常把不包括D在内的部分特指为 SSM,其方程如下:

91fdd3a7096e253e1d58588eefa08210.png c1e706ac5c1c2fcbd6cfaa5ca8f9c9e2.png

需要说明的还有两点:

  • 该系统是一个时不变(time-invariant)系统(这对后边 mamba 的改造非常关键)

  • 该系统是一个连续(continuous)系统(在实际处理中需要离散化处理)

1.1 离散化及其三种解法

ff9ce92c93081df727d2d5de94ff1b92.png
1.1.1 方法一:微分极限法
2cd6267efd2a3e32d810fd4d7ad5da0f.png

对上式稍加变换,有

59dc80253db4e73a224157ac7f09eaf8.png

代入状态方程则有

819f85589071fdb6a6ddcccbd62eddfe.png 99771e9e5ca69663a79220623a2bb1b5.png
1.1.2 方法二:双线性变换(Tustin’s method)
a4df3c5d54873d86b374bc1c745f694d.png 00c6c1e2a430066e7d6aec40cc1c3439.png

继续代入状态方程,则有

b8fbd8a529dfadaff52305e9cba0c6d2.png 80b427f6c5fb697b394ebdf750a48d3b.png 65ee03cce9d1ee97516662ad07b2d7c8.png

这也是 SSM 论文中提供的结果,在此补充完整推理计算过程。

1.1.3 方法三:零序保持(zero- order hold, ZOH)
0a94288a76122b9c12038de8263f086e.png a99bd8018feb52a42256a7f6716b9c3b.png 46a5754e9323d9e5dc3acb2f5cd3207a.png 082be43c5fefd247f900dc5b5985ed0d.png

则有

dc9c9bed45510614a3c49f090cbb6421.png

这也是 Mamba 论文中提供的结果,在此补充完整推理计算过程。

979fa453b0626bf64eb0df2cd4bf5337.png

我们可以简单计算几步体会其过程

4f85f3056f82942ef2c3f4ede3647e77.png 4844cda049db10f35a49fbba8612c625.png 40cc9537cdf58596a36d2bc327ce7420.png

这样就会遇到 RNN 的困境,即无法实现计算的并行化。

1.2 卷积形式与训练并行化

通过前文,我们已经认识了两种形式:原始的连续形式序列的循环形式,如下图的前两部分。连续形式在实际中不易实现,循环形式可以实现推理的线性,但是无法并行训练,这就引出了第三种形式,即卷积形式。

9febcc361457f80a91d488cfb094e810.png

首先写出离散形式下的状态方程

9f350f820632be0e881c5366fe1fa866.png

由此可以写出前几步的状态变量

bc8ddf580a9c8721c4f7f347e263e8e0.png 9f791c88323451788c69b7d04f9497ed.png

同理可以写出前几步的输出变量

e279278fd8616cc493da18a5efee117e.png 023036a658facb5e429c87ca7207f503.png c1985238b19bdb841bd777083e63ac2b.png 704fad01ce13eb80198f58674d494d86.png cbda38cb3956317478dcadb9e1b346d0.png

除此之外还有一种 HiPPO(High-Order Polynomial Projection Operator) 矩阵的初始化方法,即

c7845fc05a0f722980e4df62b3a49e9c.png

通过以上的卷积形式,即可实现训练的并行性,再结合适用于推理的线性复杂度的循环形式,即构成了完整的 SSM 结构。

6e79f2ff5618e8415feefd290e070f0a.png

以上我们详细推导了三种形式,现在有必要总结一下每种形式的特点:

  • 连续形式

  • 自动处理连续数据(如音频信号、时间序列)。这在处理具有不规则或时移采样的数据时具有巨大的优势。

  • 可以进行数学上的分析,例如通过计算精确的轨迹或构建记忆系统 。

  • 训练和推理都非常慢

  • 循环形式

  • 对于序列数据的自然归纳,原则上无上下文限制

  • 高效推理(恒定时间状态更新)

  • 学习缓慢(缺乏并行性)

  • 训练过长序列时梯度消失或爆炸

  • 卷积形式

  • 本地的、可解释的特征

  • 高效(可并行化)训练

  • 在线或自回归上下文中的速度较慢(必须重新计算每个新数据点的整个输入)

  • 固定上下文大小

以上内容已经全面介绍了 SSM 的诸多细节,但对于文本生成的场景,便可以比较清楚发现其问题所在。在 Transformer 架构中,前文的信息存储在了 KV cache 中,而 SSM 中不存在类似模块,因此其循环形式并不擅长处理上下文学习相关问题。那么这就引出了 Mamba 的优化和改进。

8da05b3b75958761a774cbc48631327a.png

二、Mamba 的突围

概括而言,SSM 模型的问题主要可以总结为以下两个方面:

  1. 选择性复制:在语言模型的上下文中,选择性复制是指从给定输入中辨别相关信息并将其适当地再现或合并到生成的输出中的能力。它涉及模型从输入数据中识别和重现特定短语、实体或模式的能力,从而增强生成文本的相关性和连贯性。

  2. 归纳头:语言模型中的归纳头属于专门组件,可促进模型从输入数据中推断和概括知识的能力。与人类根据观察到的模式得出结论和进行推论的方式类似,归纳头使模型能够推断信息、理解潜在的关系,并应用学到的概念来生成更细致、更适合上下文的响应。

2.1 Mamba 的选择机制

普通 SSM 无法学习选择性复制任务的原因是它们是时不变的(内核是固定的)。这意味着可学习参数随着每个新 token 的传入而保持固定。

13ef24ecdd2b9a5d5642ab12e621741f.png

Mamba 的第一个改进就是选择性扫描(Selective Scan)操作,即放弃卷积和递归的双重属性,仅依赖于循环。具体来说就是时变参数化。矩阵 A(HiPPO 矩阵)保持不变,但 Δ、B 和 C 现在成为输入的函数。下图展示了二者选择机制的差异。

d06604b39c07184dbedcebe9aa89473d.png 0c71b35747634be9968aae606d98fde7.png

以下为 SSM+Selection 的一个简单实现:

def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args: x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns: output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        
        return y

2.2 训练加速 —— 并行扫描

由于 Mamba 的上述操作,使其仅具有循环形式,而不具有卷积形式,那么其训练就无法实现并行化。这就需要研究其他方法来加速训练。每个新状态都是当前输入和先前状态的总和,后者是迄今为止所有先前状态的递归总和。

df56afec2caba8e674909064818d8dbc.png

这种对数组的前缀和运算也称为扫描操作。因此,该算法的一个简单解决方案就是简单地循环遍历数组并跟踪前缀和,每个新的和将是前缀和+当前输入。这样就有 O(N)的时间复杂度,因此不可并行化。该算法的名称称为并行扫描,下图为其工作原理:

140884850bee1b2e92eff9bc80f28fc9.png

Mamba 不可并行化(因为它是时变的),因此需要依赖循环操作。Mamba 的作者采用三种经典技术来提高循环操作速度:

  • 并行扫描算法 (Parallel Scan)

  • 核融合 (Kernel Fusion)

  • 激活重计算 (Activation Recomputation)

2.3 Mamba 结构与实现

Mamba 模型是由多层 Mamba 层连接而成,与 Transformer 模型的层非常相似。Mamba 区块的架构很大程度上受到Transformer 和 Hungry Hungry Hippo (H3) 架构的启发。

6374b8c84c3ce265bc1e781567245e40.png

Mamba 层是 H3 和门控 MLP 操作的组合。它首先将输入投影到隐藏状态维度,然后在投影维度上进行非线性卷积( SILU 激活函数)。然后计算 SSM 操作。接下来我们进行跳跃连接操作。最后,用另一个线性投影缩小张量的大小就完成了。

45031bd65ff726d82050df7a630de6bf.png

下面看一下核心部分的官方实现,如下:

class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            ))

官方也提供了多个版本的 checkpoint,下面以 mamba-2.8b 为例,实现其端到端的推理过程

import time
import json
import torch
import torch.nn.functional as F

from einops import rearrange
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model_name = "state-spaces/mamba-2.8b"
prompt = "I have a dream and I"
promptlen = 100
genlen = 100
temperature = 1.0
topk = 1
topp = 1.0
minp = 0.0
repetition_penalty = 1.1
batch =1

repeats = 3
device = "cuda:0"
dtype = torch.float16

print(f"Loading model {model_name}")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=dtype)
model.eval()
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

torch.random.manual_seed(0)
tokens = tokenizer(prompt, return_tensors="pt")
input_ids = tokens.input_ids.to(device=device)
attn_mask = tokens.attention_mask.to(device=device)
max_length = input_ids.shape[1] + genlen

fn = lambda: model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=temperature,
        top_k=topk,
        top_p=topp,
        min_p=minp,
        repetition_penalty=repetition_penalty,
    )
out = fn()
 
print(tokenizer.batch_decode(out.sequences.tolist()))
torch.cuda.synchronize()
start = time.time()
for _ in range(repeats):
    fn()
torch.cuda.synchronize()
print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
print(f"{model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")

结果如下所示:

Loading model state-spaces/mamba-2.8b
Number of parameters: 2768345600
['I have a dream and I want to make it come true.\n\n"The first thing is that we need more people in the community who are willing to help us out, because there\'s not enough of them," he said. "We\'re going through this process now with our new building."\n\nHe added: "It would be great if they could get some volunteers from the local business owners or maybe even residents themselves so that when we do go into construction mode on the project... then at least you know somebody will']
Prompt length: 6, generation length: 100
state-spaces/mamba-2.8b prompt processing + decoding time: 569ms

三、语言模型之外的 Mamba

除了在语言模型上的应用外,Mamba 还在图像和时间序列模型上有所应用,下面分别介绍相关工作。

3.1 图像中的 Mamba

处理图像本质上比处理文本要复杂得多。因为图像不仅仅是像素的序列;它们还包含复杂的模式,变化的空间关系,以及理解整体环境的需要。这种复杂性使得视觉数据的有效处理成为一项具有挑战性的任务,特别是在规模和高分辨率下。

Mamba块是Vim的一个关键特性,通过使用位置嵌入标记图像序列,并使用双向状态空间模型压缩视觉表示,Vision Mamba可以有效地捕获图像的全局上下文。这种方法解决了可视数据固有的位置敏感性,这是传统Transformer模型经常遇到的一个关键问题,特别是在更高分辨率下。

d1a97f1dd31ad9b66e2f2fdfd0a88ba3.png

Vim模型首先将输入图像划分为小块,然后将小块投影到 token 中。这些 token 随后被输入到 Vim 编码器中。对于像 ImageNet 分类这样的任务,在 token 序列中添加了一个额外的可学习分类标记,与用于文本序列建模的Mamba 模型不同,Vim编码器在正向和反向两个方向上处理 token 序列。

ViM 的突出特点之一是其双向处理能力,类似于 LSTM 的工作原理。与许多以单向方式处理数据的模型不同,ViM 的编码器可以向前和向后处理 token。双向模型可以更丰富地理解图像上下文,这是准确图像分类和分割的关键因素。

一旦 token 被卷积并激活,算法就会执行额外的线性变换并应用 softplus 函数,以确保输出值保持正值。这些转换为 SSM 序列建模功能准备 token。

在 SSM 操作之后,该算法应用门控机制,通过 SSM 输出与 SiLU 激活的前向和后向序列的元素相乘来调制信息流。这种门控机制可能旨在控制每个方向处理的贡献。

5cc4458cf0db942ad41dddb8159da186.png

最后一步是残差连接,将原始输入序列添加到门控输出中,这有助于保留前几层的信息并解决梯度消失问题。整个过程的输出是一个新的标记序列,它可能经过了复杂的转换,捕捉到了序列两个方向上的复杂依赖关系。最后,算法会返回这个经过转换的 token 序列。

在 ImageNet 分类、COCO 物体检测和 ADE20K 语义分割等基准测试中,ViM 不仅表现更好,而且效率更高。例如,在处理高分辨率图像(1248 × 1248)时,ViM 比 DEIT 快 2.8 倍,使用的 GPU 内存减少了 86%。考虑到处理高分辨率图像时通常面临的内存限制,这无疑是一个重大改进。

3.2 时间序列中的 Mamba

通过开始阶段的研究和讨论,可以看到,SSM 本来就是用来处理时间序列信号的,因此 Mamba 也可以直接用来处理时间序列的任务,以下是其简单的实现:

class TSModel(nn.Module):
def __init__(self, d_model, d_state, d_conv, expand, forecast, lookback, dropout=0.5, device="cpu"):
    super(TSModel,self).__init__()
    self.device=device
    self.mamba = Mamba(d_model=d_model, d_state=d_state,d_conv=d_conv, expand=expand).to(device)
    self.d1_nn = nn.Dropout(p=dropout).to(device)
    self.fc1=nn.Linear(in_features=lookback*d_model, out_features=forecast).to(device)

def forward(self, input):
    bs=input.shape[0]
    h_out = self.mamba(input)
    h_out = rearrange(h_out, 'b l c -> b (l c)')
    h_out = self.d1_nn(h_out)

    out = self.fc1(h_out)
    return out

def predict(self,input):
    with torch.no_grad():
        predictions=self.forward(input)
    return predictions

最后,比较一下 Transformer, RNN, Mamba 的区别:

ea3efc7044a50b91e173d5c33fd04a8c.png

投稿作者为『自动驾驶之心知识星球』特邀嘉宾,欢迎加入交流!

① 全网独家视频课程

BEV感知、毫米波雷达视觉融合多传感器标定多传感器融合多模态3D目标检测车道线检测轨迹预测在线高精地图世界模型点云3D目标检测目标跟踪Occupancy、cuda与TensorRT模型部署大模型与自动驾驶Nerf语义分割自动驾驶仿真、传感器部署、决策规划、轨迹预测等多个方向学习视频(扫码即可学习

447e0d096ceef07df79c0fd73141aee9.png

网页端官网:www.zdjszx.com

② 国内首个自动驾驶学习社区

国内最大最专业,近3000人的交流社区,已得到大多数自动驾驶公司的认可!涉及30+自动驾驶技术栈学习路线,从0到一带你入门自动驾驶感知2D/3D检测、语义分割、车道线、BEV感知、Occupancy、多传感器融合、多传感器标定、目标跟踪)、自动驾驶定位建图SLAM、高精地图、局部在线地图)、自动驾驶规划控制/轨迹预测等领域技术方案大模型、端到端等,更有行业动态和岗位发布!欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频

7840778eeb59c52d8c00b9b03b235df4.png

③【自动驾驶之心】技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦感知、定位、融合、规控、标定、端到端、仿真、产品经理、自动驾驶开发、自动标注与数据闭环多个方向,目前近60+技术交流群,欢迎加入!

自动驾驶感知:目标检测、语义分割、BEV感知、毫米波雷达视觉融合、激光视觉融合、车道线检测、目标跟踪、Occupancy、深度估计、transformer、大模型、在线地图、点云处理、模型部署、CUDA加速等技术交流群;

多传感器标定:相机在线/离线标定、Lidar-Camera标定、Camera-Radar标定、Camera-IMU标定、多传感器时空同步等技术交流群;

多传感器融合:多传感器后融合技术交流群;

规划控制与预测:规划控制、轨迹预测、避障等技术交流群;

定位建图:视觉SLAM、激光SLAM、多传感器融合SLAM等技术交流群;

三维视觉:三维重建、NeRF、3D Gaussian Splatting技术交流群;

自动驾驶仿真:Carla仿真、Autoware仿真等技术交流群;

自动驾驶开发:自动驾驶开发、ROS等技术交流群;

其它方向:自动标注与数据闭环、产品经理、硬件选型、求职面试、自动驾驶测试等技术交流群;

扫码添加汽车人助理微信邀请入群,备注:学校/公司+方向+昵称(快速入群方式)

dca5de2c457233c218024bdb17b5ebcf.jpeg

④【自动驾驶之心】硬件专场

a149b25a7d237bf604fbba153cff0c5c.jpeg

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值