本文提供了从零开始构建 Mamba 的全部代码过程,作者将Mamba算法模型从理论转化为具体实践。这一探索过程不仅可以巩固对 Mamba 内部工作原理的理解,而且还展示了新颖算法模型架构的实际设计步骤。
在深度学习领域,序列建模仍然是一项具有挑战性的任务,通常由 LSTM 和 Transformers 等模型来解决。然而,这些模型的计算量很大,因此在实际应用场景中,这些模型方法仍存在巨大的缺陷。而Mamba 是一个线性时间序列建模框架,其旨在提高序列建模的效率和有效性。本文将深入探讨使用 PyTorch 实现 Mamba 的过程,解码这一创新方法背后的技术问题和代码。
1 模型架构对比
1.1 Transformer:
Transformer因其注意机制而闻名。借助于Transformer操作特性,特征序列中的任何部分都可以与其他部分进行动态交互,尤其是因果注意力特征,能够很好的捕获因果特征的信息。因此,Transformer能够处理好序列中的每一个元素,相应的,Transformer的计算代价和内存成本也都很高,与序列长度(L²)的平方成正比。
1.2 递归神经网络(RNN):
RNN 是按照序列顺序更新隐藏状态,它只考虑当前输入特征和上一个隐藏状态信息。这种方法允许它们以恒定的内存成本处理无限长的序列。然而,RNN 的简单性也变相的成为一个缺点,即限制了其记忆长期依赖关系的能力。此外,尽管有 LSTM 这样的创新,RNN 中的时间反向传播(BPTT)机制可能会占用大量内存,并可能出现梯度消失或爆炸的问题。
1.3.状态空间模型(S4):
状态空间模型具有良好的特性。它们提供了一种计算代价和内存成本的平衡,比 RNNs 更有效地捕捉长程依赖性,同时比 Transformers 更节省内存。
图1|序列建模网络架构发展©️【深蓝AI】
1.4.Mamba架构的方法思路:
●选择性状态空间:Mamba 以状态空间模型的概念为基础,引入了一种新的模型架构设计思路。它利用选择性状态空间,能更高效、更有效地捕捉长序列中的相关信息。
●线性时间复杂性:与Transformers不同,Mamba的运行时间与序列长度成线性关系。这一特性使其特别适用于超长序列的任务,而传统的模型在这方面会很吃力。
图2|Mamba引入选择性状态空间©️【深蓝AI】
Mamba 通过其 "选择性状态空间"(Selective State Spaces)的概念,为传统的状态空间模型引入了一个新颖的架构。这种方法稍微放宽了标准状态空间模型的僵化状态转换,使其更具适应性和灵活性,有点类似于 LSTM。不过,Mamba 保留了状态空间模型的高效计算特性,使其能够一次性完成整个序列的前向传递。
2 代码实现
2.1导入必须的库文件
在简单介绍完Mamba架构之后,为大家带来Mamba的代码实现过程,首先导入必须的库。
2.2 设置标识和训练设备
这里主要针对是否使用GPU,以及Mamba的选择设定对应的表示、以及所使用的设备。
2.3 设置初始化超参数
这一小节定义了模型维度(d_model)、状态大小、序列长度和批次大小等超参数。
2.4 定义S6模块
S6 模块是 Mamba 架构中的一个复杂组件,它主要由一系列线性变换和离散化过程组成,用于处理输入的特征序列。它在捕捉序列的时间动态特征方面起着至关重要的作用,而时间动态特征是语言建模等序列建模任务的一个关键方面。
S6 模块继承于 nn.Module,是 Mamba 算法模型的关键部分,负责处理离散化过程和前向传播。
2.5 定义MambaBlock模块
MambaBlock 模块是一个定制的神经网络模块,是 Mamba 模型的关键部件,它封装了处理输入数据的多个网络层和操作函数。MambaBlock 模块代表一个复杂的神经网络模块,包括线性投影、卷积、激活函数、自定义 S6 模块和残差连接。该模块是 Mamba 模型的基本组成部分,通过一系列转换处理输入序列,以捕捉数据中的相关模式和特征。这些不同网络层和操作函数的组合使 MambaBlock 能够有效处理复杂的序列建模任务。
MambaBlock 模块是另一个封装了 Mamba 核心功能的模块,包括输入投影、一维卷积和 S6 模块。
2.6 定义Mamba模型
Mamba 类代表 Mamba 模型的整体架构,由一系列 MambaBlock 模块组成。每个模块负责处理输入的序列数据,一个模块的输出作为下一个模块的输入。这种顺序处理使模型能够捕捉输入数据中的复杂模式和关系,从而有效地完成顺序建模的任务。多个模块的堆叠是深度学习架构中常见的设计,因为它能让模型学习数据的分层表示特征。
该类定义了整个 Mamba 模型,将多个 MambaBlock 模块链接在一起,构成整体算法模型的架构。
2.7 定义RMSNorm模块
RMSNorm 模块是一个自定义的归一化层,继承了 PyTorch 的 nn.Module。该层用于对神经网络的激活值进行归一化操作,这有助于加快训练速度。
RMSNorm 模块是用于归一化的均方根网络层,是神经网络架构中的一种常用技术。
3 使用介绍
本节介绍如何在简单的数据样本上实例化和使用 Mamba 算法模型。
3.1数据准备和训练函数
Enwiki8Dataset 类是一个自定义数据集处理程序,它继承自 PyTorch 的 Dataset 类,专门用于为序列建模任务(如语言建模)而构建的数据集。
train 函数用于训练 Mamba 算法模型。
def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
●model(模型):要训练的神经网络模型(本例中为 Mamba);
●tokenizer:处理输入数据的标记符;
●data_loader:数据加载器,一个可迭代器,用于为训练提供成批数据;
●optimizer: 优化器:用于更新模型权重的优化算法;
●criterion:用于评估模型性能的损失函数;
●设备:模型运行的设备(CPU 或 GPU);
●max_grad_norm:用于梯度剪切的值,以防止梯度爆炸;
●DEBUGGING_IS_ON:启用调试信息的标志。
3.2 模型训练循环
上述代码是建立和训练 Mamba 模型的详细示例过程,包括数据集的组合和划分,模型的定义和初始化,以及损失函数和优化器的定义,最后则是设定训练循环的次数。
4 总结
本文提供了从零开始构建 Mamba 的全部代码过程,读者们可以借助本文的讲解和代码,将Mamba算法模型从理论转化为具体实践。这一探索过程不仅可以巩固对 Mamba 内部工作原理的理解,而且还展示了新颖算法模型架构的实际设计步骤。通过这种实践方法,笔者发现了序列建模的细微差别以及 Mamba 在这一领域引入的效率。有了这些知识,笔者现在就可以在自己的项目中更好地尝试使用 Mamba,或更深入地开发新型的AI模型。
参考:
【1】https://arxiv.org/abs/2312.00752
【2】https://github.com/state-spaces/mamba
【3】https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
【4】https://huggingface.co/datasets/enwik8