Mamba复现与代码解读

🐍 在阅读代码前,建议先去了解下Mamba的原理~ Mamba 基础讲解【SSM,LSSL,S4,S5,Mamba】

环境配置

环境说明

Linux
cuda 11.8
torch==2.1.1 
torchvision==0.16.1
causal-conv1d==1.1.1
mamba-ssm==1.1.1

查看gcc版本(要保证gcc版本不得低于9)
在这里插入图片描述


创建一个名为mamba的虚拟环境

conda create -n mamba python=3.10.13
conda activate mamba

安装cudatoolkit

conda install cudatoolkit==11.8 -c nvidia

在这里插入图片描述
安装pytorch和torchvision

pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118

在这里插入图片描述

conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

在这里插入图片描述

conda install packaging

在这里插入图片描述
安装conv1d和mamba-ssm

pip install causal-conv1d==1.1.1
pip install mamba-ssm==1.1.1

这里的conv1d 和mamba-ssm的版本一定要对应上,不然会报错。

demo推理

  • 安装Transformers

方法1:直接git后安装

pip install git+https://github.com/huggingface/transformers@main

方法2:从源码安装
从官网上下载https://github.com/huggingface/transformers到本地,解压后,进入到主目录中,然后输入如下命令

pip install -v -e .

方法3: pip安装

pip install transformers
  • 下载mamba的权重文件

Mamba 预训练权重下载地址:Mama- Pretrained models on Hugging Face
在这里插入图片描述
将上面红框中的文件下载后放入到一个文件夹中,文件名为mamba-130m-hf

  • 测试demo
    新建一个demo.py文件内容如下:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
path="./path/to/mamba-130m-hf" # 修改这里的路径
# 加载模型
tokenizer = AutoTokenizer.from_pretrained(path,local_files_only=True)
model = MambaForCausalLM.from_pretrained(path,local_files_only=True)
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
# 生成输出结果
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

然后运行demo.py, 得到如下输出结果。
在这里插入图片描述

源码解析

这里讲解的是github上的一个关于mamba最小实现的仓库,重点在于理解Mamba架构的实现。https://github.com/johnma2006/mamba-minimal

参数解读

  • b: 批量大小(batch size), 对应Mamba论文中algorithm 2中的B
  • l : 序列长度,对应Mamba论文中algorithm 2中的L
  • d或者d_model: 隐藏层的维度大小
  • n或者d_state: 状态维度,对应Mamba论文中algorithm 2中的N
  • expand : 扩张系数,Mamba论文3.4节的E
  • d_in或者d_inner : d*expand, 对应Mamba论文中algorithm 2中的D
  • A,B,C,D对应的是状态空间模型的参数。其中B,C是依赖于输入的,A,D并不是。
  • Δ 或者 delta : 依赖于输入的时间步长。
  • dt_rank: Δ的秩,对应Mamba论文中3.6节的“parameterization of Δ”

在这里插入图片描述

Mamba块(Mamba Block)

状态空间模型(SSM)


代码解析: ssm模型实现了Algorithm 2的整个步骤。其中第5,6步调用的是selective_scan函数实现的选择性扫描算法。

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D 是独立于输入的 (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C 是依赖于输入的 (这是Mamba模型和 linear time invariant S4 的主要区别,这也是为什么Mamba被称为selective state spaces
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float() # (d_in,)

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)   
        
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
        # delta: (b, l, dt_rank) 
        # B, C: (b, l, n)  
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        y = self.selective_scan(x, delta, A, B, C, D)  # 选择性扫描算法
        return y

delta = F.softplus(self.dt_proj(delta))参考论文中描述的:
在这里插入图片描述
Δ \Delta Δ 在SSM中的作用,类似于RNN中的门控机制。

选择性扫描算法(selective_scan)


Selective SSM 原理介绍
下图是Mamba论文中的算法介绍:
上图中算法的核心是第5步和第6步:

  • 第5步是对连续的矩阵A,B进行离散化得到离散化后的矩阵 A ˉ \bar A Aˉ B ˉ \bar B Bˉ, 离散化的方法有欧拉方法和零阶保持(Zero-order hold, ZOH)方法。
    下面的公式是ZOH算法:
    A ˉ = exp ⁡ ( Δ A ) B ˉ = ( Δ A ) − 1 ( exp ⁡ ( Δ A ) − I ) ⋅ Δ B \bar{A}=\exp (\Delta A) \quad \bar{B}=(\Delta A)^{-1}(\exp (\Delta A)-I) \cdot \Delta B Aˉ=exp(ΔA)Bˉ=(ΔA)1(exp(ΔA)I)ΔB

  • 第6步是对离散化后的矩阵 A ˉ \bar A Aˉ B ˉ \bar B Bˉ以及C执行SSM算法,其中离散的SSM方程如下
    x ( t + 1 ) = A ˉ x ( t ) + B ˉ u ( t ) y ( t ) = C x ( t ) + D u ( t ) x(t + 1) = \bar A x(t) + \bar B u(t) \\ y(t) = Cx(t) + Du(t) x(t+1)=Aˉx(t)+Bˉu(t)y(t)=Cx(t)+Du(t)
    其中 x x x表示的隐藏状态, u u u表示的输入, y y y表示的是输出。


selective_scan代码解析
下面的算法主要实现的就是Algorithm2中的第5步和第六步。
在第五步中,代码中采用ZOH对矩阵A进行离散化,但是作者并没有采用ZOH对B进行离散化,而是采用了一种更简化的方式(因为主要的参数是A, 对B进行简化并不会影响实验的性能)
在第六步中,代码中使用for循环的方式执行SSM, 主要是为了说明SSM的核心功能,其并行扫描算法可以参见源码(https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86

    def selective_scan(self, u, delta, A, B, C, D):
    '''    
        Args:
            u: shape (b, l, d_in)    输入x,(B,L,D)
            delta: shape (b, l, d_in)  离散步长,(B,L,D)
            A: shape (d_in, n)  连续的矩阵A,(D,N)
            B: shape (b, l, n)  连续的矩阵B(B,L,N)
            C: shape (b, l, n)  (B,L,N)
            D: shape (d_in,)    (D,)
    
        Returns:
            output: shape (b, l, d_in)   输出:(B,L,D)
    
        官方实现版本:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
       '''      
      
        (b, l, d_in) = u.shape #(B,L,D)
        n = A.shape[1] # N
        
        '''
        对连续的参数(A, B)进行离散化
        A 使用零阶保持法(zero-order hold, ZOH)进行离散化 (see Section 2 Equation 4 in the Mamba paper [1])
        B 则使用一种简化的Euler方法进行离散化
        B没有使用ZOH的原因,作者解释如下: "A is the more important term and the performance doesn't change much with the simplification on B"
        '''
        # einsum 操作实际上是将 delta 和 A 张量的最后两个维度进行矩阵乘法,并在前面添加了两个维度(b 和 l)
        # torch.exp 函数对这个张量中的每个元素进行指数化,即将每个元素取指数值。
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))   # (B,L,D) * (D,N) -> (B,L,D,N)
        # 将 delta、B 和 u 张量的对应位置元素相乘,并在最后一个维度上进行求和,输出一个新的张量。
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')  # (B,L,D)*(B,L,N)*(B,L,D)->(B,L,D,N)
        
        '''
        执行 selective scan (see scan_SSM() in The Annotated S4 [2])
        # 注意,下面的代码是顺序执行的, 然而在官方代码中使用更快的并行扫描算法实现的(类似于FlashAttention,采用硬件感知扫描)。
        '''
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []    
        for i in range(l):  # 这里使用for循环的方式只用来说明核心的逻辑,原代码中采用并行扫描算法
            x = deltaA[:, i] * x + deltaB_u[:, i] # x(t + 1) = Ax(t) + Bu(t)
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') # y(t) = Cx(t)  (B,D,N)*(B,N)->(B,D) 
            ys.append(y)
        y = torch.stack(ys, dim=1)  # 大小 (b, l, d_in)  (B,L,D)
        y = y + u * D # y(t) = Cx(t)+Du(t)
        return y #(B,L,D)

前向传播(forward)


代码解析
在这里插入图片描述
这段代码实现的就是上图的架构。

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
    
        Args:
            x: shape (b, l, d)    
    
        Returns:
            output: shape (b, l, d)  
            
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (b, l, d) = x.shape # shape (b,l,d)

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)# x: (b,l,d_in), res: (b,l,d_in)

        x = rearrange(x, 'b l d_in -> b d_in l')  # shape (b,l,d_in)->(b,d_in,l)
        x = self.conv1d(x)[:, :, :l]     # (b,d_in,l)
        x = rearrange(x, 'b d_in l -> b l d_in') # (b,d_in,l)->(b,l,d_in)
        x = F.silu(x) # (b,l,d_in)
        y = self.ssm(x) # (b,l,d_in)
        y = y * F.silu(res) # (b,l,d_in)
        
        output = self.out_proj(y) # (b,l,d_in)-> (b,l,d)
        return output # (b,l,d)

均方根归一化 (RMSNorm)


RMS Normalization 简介
RMS(Root Mean Square)Normalization(均方根归一化)是一种用于神经网络中的归一化技术,旨在将输入数据的分布调整为更适合训练的状态。它与 Batch Normalization(批归一化)和 Layer Normalization(层归一化)等技术类似,但有一些独特之处。

对于输入张量 X X X 的形状为 ( B , L , D (B, L, D (B,L,D ),其中 B B B 是批量大小 (batch size), L L L 是序列长度, D D D是每个样本的特征维度。RMS归一化的过程如下:

  1. 计算每个样本的均方根值:
    RMS ⁡ ( X b , i ) = 1 D ∑ d = 1 D X b , i , d 2 \operatorname{RMS}\left(X_{b, i}\right)=\sqrt{\frac{1}{D} \sum_{d=1}^D X_{b, i, d}^2} RMS(Xb,i)=D1d=1DXb,i,d2

其中, X b , i X_{b, i} Xb,i 表示输入张量的第 b b b 个样本中的第 i i i 个样本, X b , i , d X_{b, i, d} Xb,i,d 表示该样本的第 d d d 个特征值。

  1. 对每个样本进行归一化:
    RMS ⁡ _ Norm ⁡ ( X b , i ) = X b , i RMS ⁡ ( X b , i ) \operatorname{RMS} \_\operatorname{Norm}\left(X_{b, i}\right)=\frac{X_{b, i}}{\operatorname{RMS}\left(X_{b, i}\right)} RMS_Norm(Xb,i)=RMS(Xb,i)Xb,i
    这里将每个特征值除以其所在样本的均方根值, 从而使得每个样本的均方根值归一化后为 1 。

RMS归一化的优点包括:

  • 适用于小批量数据: RMS归一化对于小批量数据或数据量较少的情况更为适用,因为它只关注单个样本的统计信息。
  • 适用于变长序列: 由于RMS归一化是对每个样本进行归一化,因此对于变长序列的处理更加简便。

RMS归一化也有一些局限性,例如在某些情况下可能会导致梯度爆炸或消失问题,并且对于较大的数据集可能不如Batch Normalization效果好。因此,选择归一化技术时需要根据具体情况进行权衡和选择。


RMSNorm代码解析
RMS归一化的过程如下:

  • 对于输入张量 x x x,首先计算其每个样本在最后一个维度(通常是特征维度)上的平方: x 2 x^2 x2
  • 然后计算每个样本在最后一个维度上的平均值: x 2 . m e a n ( − 1 , k e e p d i m = T r u e ) x^2.mean(-1, keepdim=True) x2.mean(1,keepdim=True)。这里的-1表示最后一个维度,keepdim=True表示保持维度不变。
  • 加上一个很小的常数 ϵ \epsilon ϵ,以防止除以零: x 2 . m e a n ( − 1 , k e e p d i m = T r u e ) + ϵ x^2.mean(-1, keepdim=True) + \epsilon x2.mean(1,keepdim=True)+ϵ
  • x x x除以上述结果的平方根,得到归一化的系数: t o r c h . r s q r t ( x 2 . m e a n ( − 1 , k e e p d i m = T r u e ) + ϵ ) torch.rsqrt(x^2.mean(-1, keepdim=True) + \epsilon) torch.rsqrt(x2.mean(1,keepdim=True)+ϵ)
  • 最后乘以可学习的权重weight,得到最终的归一化结果。
class RMSNorm(nn.Module):
    '''
    均方根归一化: RMS normalization
    '''
    def __init__(self,
                 d_model: int, # hidden dim
                 eps: float = 1e-5): # 防止除以零的小数值
        super().__init__()
        self.eps = eps
        # weight: 可学习的参数,调整归一化后的值
        self.weight = nn.Parameter(torch.ones(d_model)) # 初始值为大小为d_model的张量,每个元素的值都是1

    def forward(self, x):
        '''
        :param x: 输入张量
        :return: output 均方根规划化后的值

        RMS的计算步骤:
        Step1: 计算每个样本的均方根值
            Step1.1: 先计算x的平方
            Step1.2: 沿着最后一个维度(通常是特征维度)计算平均值,并加上一个很小的数eps
            Step1.3: 最后取平方根
        Step2: 对每个样本进行归一化
            Step2.1:每个特征值除以其所在样本的均方根值
            Step2.2: 最后乘以可以学习的权重weight,得到最终的输出
        '''
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output

残差块(ResidualBlock)


ResidualBlock代码解析
为Mamba Block 添加 normalization 和 残差连接 (下图红框中的部分)

在这里插入图片描述

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args) # Mamba块
        self.norm = RMSNorm(args.d_model) # RMS归一化    

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (batch size, sequence length, hidden dim)
        Returns:
            output: shape (b, l, d)  (batch size, sequence length, hidden dim)    
        """
        output = self.mixer(self.norm(x)) + x # [Norm -> Mamba -> Add]  
        return output

Mamba架构

在这里插入图片描述
下面的代码主要是实现的就是上图中红框标注出的部分

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args
        
        self.embedding = nn.Embedding(args.vocab_size, args.d_model) # 词嵌入,其中包含 `args.vocab_size` 个不同的词或标记,每个词嵌入的维度为 `args.d_model`
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])# n_layer个ResidualBlock
        self.norm_f = RMSNorm(args.d_model) #RMS归一化

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper


    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)
    
        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """
        x = self.embedding(input_ids) # (b,l,d)  生成词嵌入
        
        for layer in self.layers:   # 通过n_layer个ResidualBlock
            x = layer(x)   # (b,l,d)
            
        x = self.norm_f(x) # (b,l,d)
        logits = self.lm_head(x) # (b,l,vocab_size)

        return logits

其中,nn.Embedding(args.vocab_size, args.d_model) 创建了一个词嵌入层,其中包含 args.vocab_size 个不同的词或标记,每个词嵌入的维度为 args.d_model。这意味着每个词或标记将被表示为一个长度为 args.d_model 的向量。

评论 27
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zyw2002

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值