文章目录
🐍 在阅读代码前,建议先去了解下Mamba的原理~ Mamba 基础讲解【SSM,LSSL,S4,S5,Mamba】
环境配置
环境说明
Linux
cuda 11.8
torch==2.1.1
torchvision==0.16.1
causal-conv1d==1.1.1
mamba-ssm==1.1.1
查看gcc版本(要保证gcc版本不得低于9)
创建一个名为mamba的虚拟环境
conda create -n mamba python=3.10.13
conda activate mamba
安装cudatoolkit
conda install cudatoolkit==11.8 -c nvidia
安装pytorch和torchvision
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
安装conv1d和mamba-ssm
pip install causal-conv1d==1.1.1
pip install mamba-ssm==1.1.1
这里的conv1d 和mamba-ssm的版本一定要对应上,不然会报错。
demo推理
- 安装
Transformers
方法1:直接git后安装
pip install git+https://github.com/huggingface/transformers@main
方法2:从源码安装
从官网上下载https://github.com/huggingface/transformers到本地,解压后,进入到主目录中,然后输入如下命令
pip install -v -e .
方法3: pip安装
pip install transformers
- 下载mamba的权重文件
Mamba 预训练权重下载地址:Mama- Pretrained models on Hugging Face
将上面红框中的文件下载后放入到一个文件夹中,文件名为mamba-130m-hf
- 测试demo
新建一个demo.py文件内容如下:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
path="./path/to/mamba-130m-hf" # 修改这里的路径
# 加载模型
tokenizer = AutoTokenizer.from_pretrained(path,local_files_only=True)
model = MambaForCausalLM.from_pretrained(path,local_files_only=True)
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
# 生成输出结果
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
然后运行demo.py
, 得到如下输出结果。
源码解析
这里讲解的是github上的一个关于mamba最小实现的仓库,重点在于理解Mamba架构的实现。https://github.com/johnma2006/mamba-minimal
参数解读
b
: 批量大小(batch size), 对应Mamba论文中algorithm 2中的B
l
: 序列长度,对应Mamba论文中algorithm 2中的L
d
或者d_model
: 隐藏层的维度大小n
或者d_state
: 状态维度,对应Mamba论文中algorithm 2中的N
expand
: 扩张系数,Mamba论文3.4节的E
d_in
或者d_inner
: d*expand, 对应Mamba论文中algorithm 2中的D
A
,B
,C
,D
对应的是状态空间模型的参数。其中B,C是依赖于输入的,A,D并不是。Δ
或者delta
: 依赖于输入的时间步长。dt_rank
: Δ的秩,对应Mamba论文中3.6节的“parameterization of Δ”
Mamba块(Mamba Block)
状态空间模型(SSM)
代码解析: ssm
模型实现了Algorithm 2的整个步骤。其中第5,6步调用的是selective_scan
函数实现的选择性扫描算法。
def ssm(self, x):
"""Runs the SSM. See:
- Algorithm 2 in Section 3.2 in the Mamba paper [1]
- run_SSM(A, B, C, u) in The Annotated S4 [2]
Args:
x: shape (b, l, d_in)
Returns:
output: shape (b, l, d_in)
Official Implementation:
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
(d_in, n) = self.A_log.shape
# Compute ∆ A B C D, the state space parameters.
# A, D 是独立于输入的 (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
# ∆, B, C 是依赖于输入的 (这是Mamba模型和 linear time invariant S4 的主要区别,这也是为什么Mamba被称为selective state spaces
A = -torch.exp(self.A_log.float()) # shape (d_in, n)
D = self.D.float() # (d_in,)
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
# delta: (b, l, dt_rank)
# B, C: (b, l, n)
delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
y = self.selective_scan(x, delta, A, B, C, D) # 选择性扫描算法
return y
delta = F.softplus(self.dt_proj(delta))
参考论文中描述的:
Δ
\Delta
Δ 在SSM中的作用,类似于RNN中的门控机制。
选择性扫描算法(selective_scan)
Selective SSM 原理介绍
下图是Mamba论文中的算法介绍:
上图中算法的核心是第5步和第6步:
-
第5步是对连续的矩阵A,B进行离散化得到离散化后的矩阵 A ˉ \bar A Aˉ和 B ˉ \bar B Bˉ, 离散化的方法有欧拉方法和零阶保持(Zero-order hold, ZOH)方法。
下面的公式是ZOH算法:
A ˉ = exp ( Δ A ) B ˉ = ( Δ A ) − 1 ( exp ( Δ A ) − I ) ⋅ Δ B \bar{A}=\exp (\Delta A) \quad \bar{B}=(\Delta A)^{-1}(\exp (\Delta A)-I) \cdot \Delta B Aˉ=exp(ΔA)Bˉ=(ΔA)−1(exp(ΔA)−I)⋅ΔB -
第6步是对离散化后的矩阵 A ˉ \bar A Aˉ和 B ˉ \bar B Bˉ以及C执行SSM算法,其中离散的SSM方程如下
x ( t + 1 ) = A ˉ x ( t ) + B ˉ u ( t ) y ( t ) = C x ( t ) + D u ( t ) x(t + 1) = \bar A x(t) + \bar B u(t) \\ y(t) = Cx(t) + Du(t) x(t+1)=Aˉx(t)+Bˉu(t)y(t)=Cx(t)+Du(t)
其中 x x x表示的隐藏状态, u u u表示的输入, y y y表示的是输出。
selective_scan代码解析
下面的算法主要实现的就是Algorithm2中的第5步和第六步。
在第五步中,代码中采用ZOH对矩阵A进行离散化,但是作者并没有采用ZOH对B进行离散化,而是采用了一种更简化的方式(因为主要的参数是A, 对B进行简化并不会影响实验的性能)
在第六步中,代码中使用for循环的方式执行SSM, 主要是为了说明SSM的核心功能,其并行扫描算法可以参见源码(https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86)
def selective_scan(self, u, delta, A, B, C, D):
'''
Args:
u: shape (b, l, d_in) 输入x,(B,L,D)
delta: shape (b, l, d_in) 离散步长,(B,L,D)
A: shape (d_in, n) 连续的矩阵A,(D,N)
B: shape (b, l, n) 连续的矩阵B(B,L,N)
C: shape (b, l, n) (B,L,N)
D: shape (d_in,) (D,)
Returns:
output: shape (b, l, d_in) 输出:(B,L,D)
官方实现版本:
selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
'''
(b, l, d_in) = u.shape #(B,L,D)
n = A.shape[1] # N
'''
对连续的参数(A, B)进行离散化
A 使用零阶保持法(zero-order hold, ZOH)进行离散化 (see Section 2 Equation 4 in the Mamba paper [1])
B 则使用一种简化的Euler方法进行离散化
B没有使用ZOH的原因,作者解释如下: "A is the more important term and the performance doesn't change much with the simplification on B"
'''
# einsum 操作实际上是将 delta 和 A 张量的最后两个维度进行矩阵乘法,并在前面添加了两个维度(b 和 l)
# torch.exp 函数对这个张量中的每个元素进行指数化,即将每个元素取指数值。
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n')) # (B,L,D) * (D,N) -> (B,L,D,N)
# 将 delta、B 和 u 张量的对应位置元素相乘,并在最后一个维度上进行求和,输出一个新的张量。
deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n') # (B,L,D)*(B,L,N)*(B,L,D)->(B,L,D,N)
'''
执行 selective scan (see scan_SSM() in The Annotated S4 [2])
# 注意,下面的代码是顺序执行的, 然而在官方代码中使用更快的并行扫描算法实现的(类似于FlashAttention,采用硬件感知扫描)。
'''
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
for i in range(l): # 这里使用for循环的方式只用来说明核心的逻辑,原代码中采用并行扫描算法
x = deltaA[:, i] * x + deltaB_u[:, i] # x(t + 1) = Ax(t) + Bu(t)
y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') # y(t) = Cx(t) (B,D,N)*(B,N)->(B,D)
ys.append(y)
y = torch.stack(ys, dim=1) # 大小 (b, l, d_in) (B,L,D)
y = y + u * D # y(t) = Cx(t)+Du(t)
return y #(B,L,D)
前向传播(forward)
代码解析
这段代码实现的就是上图的架构。
def forward(self, x):
"""Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
Args:
x: shape (b, l, d)
Returns:
output: shape (b, l, d)
Official Implementation:
class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
(b, l, d) = x.shape # shape (b,l,d)
x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)# x: (b,l,d_in), res: (b,l,d_in)
x = rearrange(x, 'b l d_in -> b d_in l') # shape (b,l,d_in)->(b,d_in,l)
x = self.conv1d(x)[:, :, :l] # (b,d_in,l)
x = rearrange(x, 'b d_in l -> b l d_in') # (b,d_in,l)->(b,l,d_in)
x = F.silu(x) # (b,l,d_in)
y = self.ssm(x) # (b,l,d_in)
y = y * F.silu(res) # (b,l,d_in)
output = self.out_proj(y) # (b,l,d_in)-> (b,l,d)
return output # (b,l,d)
均方根归一化 (RMSNorm)
RMS Normalization 简介
RMS(Root Mean Square)Normalization
(均方根归一化)是一种用于神经网络中的归一化技术,旨在将输入数据的分布调整为更适合训练的状态。它与 Batch Normalization(批归一化)和 Layer Normalization(层归一化)等技术类似,但有一些独特之处。
对于输入张量 X X X 的形状为 ( B , L , D (B, L, D (B,L,D ),其中 B B B 是批量大小 (batch size), L L L 是序列长度, D D D是每个样本的特征维度。RMS归一化的过程如下:
- 计算每个样本的均方根值:
RMS ( X b , i ) = 1 D ∑ d = 1 D X b , i , d 2 \operatorname{RMS}\left(X_{b, i}\right)=\sqrt{\frac{1}{D} \sum_{d=1}^D X_{b, i, d}^2} RMS(Xb,i)=D1d=1∑DXb,i,d2
其中, X b , i X_{b, i} Xb,i 表示输入张量的第 b b b 个样本中的第 i i i 个样本, X b , i , d X_{b, i, d} Xb,i,d 表示该样本的第 d d d 个特征值。
- 对每个样本进行归一化:
RMS _ Norm ( X b , i ) = X b , i RMS ( X b , i ) \operatorname{RMS} \_\operatorname{Norm}\left(X_{b, i}\right)=\frac{X_{b, i}}{\operatorname{RMS}\left(X_{b, i}\right)} RMS_Norm(Xb,i)=RMS(Xb,i)Xb,i
这里将每个特征值除以其所在样本的均方根值, 从而使得每个样本的均方根值归一化后为 1 。
RMS归一化的优点包括:
- 适用于小批量数据: RMS归一化对于小批量数据或数据量较少的情况更为适用,因为它只关注单个样本的统计信息。
- 适用于变长序列: 由于RMS归一化是对每个样本进行归一化,因此对于变长序列的处理更加简便。
RMS归一化也有一些局限性,例如在某些情况下可能会导致梯度爆炸或消失问题,并且对于较大的数据集可能不如Batch Normalization效果好。因此,选择归一化技术时需要根据具体情况进行权衡和选择。
RMSNorm代码解析
RMS归一化的过程如下:
- 对于输入张量 x x x,首先计算其每个样本在最后一个维度(通常是特征维度)上的平方: x 2 x^2 x2。
- 然后计算每个样本在最后一个维度上的平均值: x 2 . m e a n ( − 1 , k e e p d i m = T r u e ) x^2.mean(-1, keepdim=True) x2.mean(−1,keepdim=True)。这里的-1表示最后一个维度,keepdim=True表示保持维度不变。
- 加上一个很小的常数 ϵ \epsilon ϵ,以防止除以零: x 2 . m e a n ( − 1 , k e e p d i m = T r u e ) + ϵ x^2.mean(-1, keepdim=True) + \epsilon x2.mean(−1,keepdim=True)+ϵ。
- 让 x x x除以上述结果的平方根,得到归一化的系数: t o r c h . r s q r t ( x 2 . m e a n ( − 1 , k e e p d i m = T r u e ) + ϵ ) torch.rsqrt(x^2.mean(-1, keepdim=True) + \epsilon) torch.rsqrt(x2.mean(−1,keepdim=True)+ϵ)。
- 最后乘以可学习的权重weight,得到最终的归一化结果。
class RMSNorm(nn.Module):
'''
均方根归一化: RMS normalization
'''
def __init__(self,
d_model: int, # hidden dim
eps: float = 1e-5): # 防止除以零的小数值
super().__init__()
self.eps = eps
# weight: 可学习的参数,调整归一化后的值
self.weight = nn.Parameter(torch.ones(d_model)) # 初始值为大小为d_model的张量,每个元素的值都是1
def forward(self, x):
'''
:param x: 输入张量
:return: output 均方根规划化后的值
RMS的计算步骤:
Step1: 计算每个样本的均方根值
Step1.1: 先计算x的平方
Step1.2: 沿着最后一个维度(通常是特征维度)计算平均值,并加上一个很小的数eps
Step1.3: 最后取平方根
Step2: 对每个样本进行归一化
Step2.1:每个特征值除以其所在样本的均方根值
Step2.2: 最后乘以可以学习的权重weight,得到最终的输出
'''
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
return output
残差块(ResidualBlock)
ResidualBlock代码解析
为Mamba Block 添加 normalization 和 残差连接 (下图红框中的部分)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
"""Simple block wrapping Mamba block with normalization and residual connection."""
super().__init__()
self.args = args
self.mixer = MambaBlock(args) # Mamba块
self.norm = RMSNorm(args.d_model) # RMS归一化
def forward(self, x):
"""
Args:
x: shape (b, l, d) (batch size, sequence length, hidden dim)
Returns:
output: shape (b, l, d) (batch size, sequence length, hidden dim)
"""
output = self.mixer(self.norm(x)) + x # [Norm -> Mamba -> Add]
return output
Mamba架构
下面的代码主要是实现的就是上图中红框标注出的部分
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
"""Full Mamba model."""
super().__init__()
self.args = args
self.embedding = nn.Embedding(args.vocab_size, args.d_model) # 词嵌入,其中包含 `args.vocab_size` 个不同的词或标记,每个词嵌入的维度为 `args.d_model`
self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])# n_layer个ResidualBlock
self.norm_f = RMSNorm(args.d_model) #RMS归一化
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights.
# See "Weight Tying" paper
def forward(self, input_ids):
"""
Args:
input_ids (long tensor): shape (b, l)
Returns:
logits: shape (b, l, vocab_size)
Official Implementation:
class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
"""
x = self.embedding(input_ids) # (b,l,d) 生成词嵌入
for layer in self.layers: # 通过n_layer个ResidualBlock
x = layer(x) # (b,l,d)
x = self.norm_f(x) # (b,l,d)
logits = self.lm_head(x) # (b,l,vocab_size)
return logits
其中,nn.Embedding(args.vocab_size, args.d_model)
创建了一个词嵌入层,其中包含 args.vocab_size
个不同的词或标记,每个词嵌入的维度为 args.d_model
。这意味着每个词或标记将被表示为一个长度为 args.d_model
的向量。