【手撕代码】Mamba源码讲解


论文链接: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
(Mamba源码附在文章末尾处)
在这里插入图片描述

基础模型现在为深度学习中大多数令人兴奋的应用提供支持,它们几乎普遍基于 Transformer 架构及其核心注意模块。许多次二次时间架构(例如线性注意、门控卷积和循环模型以及结构化状态空间模型 (SSM))已被开发出来以解决 Transformer 在长序列上的计算效率低下问题,但它们在语言等重要模态上的表现不如注意。我们发现此类模型的一个关键弱点是它们无法执行基于内容的推理,并做出了一些改进。首先,简单地让 SSM 参数成为输入的函数,用离散模态解决了它们的弱点,允许模型根据当前标记沿序列长度维度选择性地传播或忘记信息。其次,尽管这种变化阻止了高效卷积的使用,但我们设计了一种循环模式下的硬件感知并行算法。我们将这些选择性 SSM 集成到一个简化的端到端神经网络架构中,无需注意甚至 MLP 块(Mamba)。 Mamba 具有快速推理(吞吐量比 Transformers 高 5 倍)和序列长度线性扩展的优势,其性能在长达百万长度的序列上可提高真实数据的性能。作为通用序列模型主干,Mamba 在语言、音频和基因组学等多种模式下实现了最先进的性能。在语言建模方面,我们的 Mamba-3B 模型在预训练和下游评估方面均优于同等大小的 Transformers,并可与两倍于其大小的 Transformers 相媲美。


论文代码定义了一个名为 Mamba 的 Torch 模块(继承自 nn.Module),用于序列建模任务。该模块融合了状态空间模型 SSM、卷积操作和神经网络投影层,以高效地对序列数据进行处理。它可以一次性处理完整序列,也可以在自回归推理时逐步处理一个时间步的输入。
在这里插入图片描述

Mamba 模块接收形状为 (B, L, D) 的输入,其中 B 是批大小,L 是序列长度,D 是特征维度(模型维度),输出与输入同维度

  1. 线性投影 Linear Projections:将输入映射到更高维度进行处理,并在最终映射回原维度。
  2. 深度卷积 Depthwise Convolution:通过一维卷积在特征维度上进行本地上下文建模。
  3. 状态空间模型 SSM 参数化:使用参数 ABCD 来表示连续时间状态方程的离散化版本,从而捕获长程依赖。
  4. 离散化参数 dt:可学习的时间步参数,用于将连续时间状态模型离散化到离散时间步。
  5. 自定义 CUDA 内核和快速路径 Fast Path:若可用,则利用预编译的高效 CUDA 内核,加速卷积和状态更新操作。

1 __init__ 函数

def __init__(  # 可以看看初始化部分的变量
        self,
        d_model,  # 模型隐藏层的维度
        d_state=16,   # 内部状态空间的维度,默认为16
        d_conv=4,  # 定义卷积核的维度,默认为4
        expand=2,   # 定义扩展因子,默认为2
        dt_rank="auto",   # 定义输入依赖步长Δ的秩,'auto'表示自动设置
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,  # 定义卷积层使用偏置项
        bias=False,  # 定义其他层(如线性层)是否使用偏置项
        use_fast_path=True,  # 是否启用快速路径(默认为True),可以启用一些特殊的融合操作以加速计算。
        layer_idx=None,  # 表示层的索引,用于在推理阶段缓存状态。
        device=None,  # 指定使用cpu还是gpu
        dtype=None,  # 指定数据类型
    ):
        factory_kwargs = {"device": device, "dtype": dtype}  # 超参数字典
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)  # 计算内部维度,即扩展后的维度
# self.d_inner 代表了扩展后的隐藏层维度,d_model 是原始隐藏层维度,expand 是扩展因子。这里的计算表示:模型的内部维度是隐藏层维度的两倍
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank   
        # 如果,dt_rank != 'auto',根据隐藏层维度自动计算Δ的秩
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
		
        # 输入线性变化层
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
# 这是一个线性层,将输入的隐藏状态从 d_model 映射到 2 * d_inner。2 * d_inner 是为了后续操作中将输入拆分成 x 和 z 两个部分。

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"  # 指定激活函数
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(  # 将输入x映射到状态空间模型的参数Δ、B和C
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        # 将Δ从dt_rank维度映射到d_inner维度,即输入层到隐藏层
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        # 根据初始化类型 dt_init(如 "constant" 或 "random")来初始化 dt_proj 权重,保持方差不变。
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        # dt_proj 偏置项使用 torch.exp 和 torch.log 来初始化,使得 dt 在一定范围内,保持合适的初始化。        
        # dt=exp(Uniform(log(dt_min),log(dt_max)))
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        # 
        # 这是为 SSM 初始化的矩阵 A 的对数值。A 是根据状态维度 d_state 来构建的矩阵,并对其取对数,作为模型的一部分。
        
        A = repeat( 
# torch.arange(1, self.d_state + 1) 会生成一个一维张量,其元素从 1 开始,到 d_state 结束(不包括 d_state + 1)。
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
# repeat 是 einops 库中的一个函数,它的作用是沿着新的维度对输入张量进行重复
# "n -> d n" 是 einops 的模式字符串,表示沿着新维度(即 d)重复原始张量 n 次。原始张量的维度 n 是 d_state(即张量的长度),然后沿新维度 d 重复,直到 d_inner 大小。
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log) # 将 A_log 转化为一个可训练的 PyTorch 参数,并将其赋值给 self.A_log
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True
'''
nn.Parameter(torch.ones(self.d_inner, device=device)):
torch.ones(self.d_inner, device=device) 创建了一个包含 d_inner 个元素的张量,每个元素初始化为 1,并且该张量被移动到指定的 device(例如 CPU 或 GPU)。然后,nn.Parameter 将这个张量包装为一个可学习的参数,使得在训练过程中,PyTorch 会自动更新 D 的值。
'''

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
# factory_kwargs 是一个包含附加参数的字典,传递给 nn.Linear 的其他超参数,通常包括 device 和 dtype(即将参数放置在哪个设备上,以及它们的数据类型)。这些参数会确保模型在不同设备(如 GPU)上运行时,数据的类型和存储位置都被正确处理。

2 _get_states_from_cache 状态缓存函数

主要作用是从缓存中获取或初始化当前层的状态信息,通常用于推理过程中。这些状态信息包含了卷积层的状态(conv_state)和状态空间模型的状态(ssm_state)。这些状态对于递归神经网络 RNN 或类似架构在时间步长之间的更新是至关重要的。

conv_statessm_state 维度都是 [B, L, D],L: 序列长度 (sequence length),D: 特征维度 (dimensionality of the features)

检查当前层的状态是否已经缓存,如果缓存中不存在状态信息,则初始化并保存它。

如果缓存中存在状态信息,则根据需要返回现有的状态。并且如果指定了 initialize_statesTrue,可以选择重新初始化(清零)这些状态。

 def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None  # self.layer_idx必须初始化
        if self.layer_idx not in inference_params.key_value_memory_dict:
# inference_params.key_value_memory_dict 是一个字典,存储了每一层的状态信息。键是层的索引(self.layer_idx),值是包含卷积状态和状态空间模型状态的元组。
# 如果 self.layer_idx 没有在 key_value_memory_dict 中找到,意味着当前层的状态信息还没有缓存或已过期。
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,  # BatchSize
                self.d_model * self.expand,  # 序列长度
                self.d_conv,  # 特征维度
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,  
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
# initialize_states 为 True 时,表示需要重置状态,将 conv_state 和 ssm_state 全部置零。
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state

3 allocate_inference_cache 函数

为推理过程分配和初始化状态缓存。这个函数会根据批次大小 (batch_size)、序列长度 (max_seqlen),以及指定的数据类型 (dtype),为卷积状态(conv_state)和状态空间模型状态(ssm_state)分配内存并初始化它们。

 def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
# max_seqlen: 序列的最大长度。这个参数并没有在函数体内直接使用,但可能是为了与后续推理过程中的动态长度相兼容。
# dtype: 数据类型,用于指定生成的缓存状态张量的数据类型。如果未提供,默认使用卷积层权重的 dtype。
# **kwargs: 其他额外的参数,可以通过该机制传递不定长的关键字参数,但在当前实现中并未直接使用。
        device = self.out_proj.weight.device
# 获取 self.out_proj(输出线性层)的权重张量的设备信息,确保生成的状态张量与该设备保持一致。这是因为模型中的张量通常都需要位于同一设备上(例如 GPU 或 CPU)。
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
# 
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

4 step 单步推理函数

用于生成模型中的一个时间步(step)的输出,或者更新模型状态。它处理输入的 hidden_states,并基于当前状态(conv_statessm_state)计算下一个输出以及新的状态。

hidden_states: 输入张量,形状为 (B, 1, D),表示当前时间步的输入。B 是批次大小,D 是特征维度,1 表示当前时间步只有一个 token。

conv_state: 卷积层的状态,是一个张量,形状为 (B, D, W),其中 W 是卷积窗口的大小。

ssm_state: 状态空间模型的状态,是一个张量,形状为 (B, D, S),表示模型的状态空间。

def step(self, hidden_states, conv_state, ssm_state):
    	"""
        执行过程中单步处理
        hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度
        conv_state: 卷积的状态缓存 (B, D, W)
        ssm_state: 状态空间模型的状态缓存 (B, D, S)
        Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state
        """
    
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
# 确保模型当前只支持一次处理一个 token。如果批次中包含多个 token,则会抛出异常。
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
# 对输入的 hidden_states 进行线性变换(通过 in_proj 层),并去除维度 1(即批次中的 token 数量)。此时 xz 的形状为 (B, 2D),表示输入的两个部分(通常是 x 和 z)拼接在一起。
        x, z = xz.chunk(2, dim=-1)  # (B D)
# 将 xz 张量沿着最后一个维度(即特征维度)分成两个部分,得到 x 和 z,每个张量的形状为 (B, D)。

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
# 使用 torch.roll 将 conv_state 张量沿着最后一个维度(即卷积窗口维度 W)滚动一位,相当于将旧的卷积状态滑动一位,为新输入腾出空间。然后,将最新的输入 x 塞入卷积状态的最后一个位置。
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         # dt:一个张量,形状为 (B, dt_rank),表示时间增量。
         # B 和 C:每个状态空间的两个部分,形状为 (B, d_state)。
        # Don't add dt_bias here    
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        # 对 dt 进行线性变换,得到形状为 (B, d_inner) 的张量。self.dt_proj.weight 是权重矩阵。
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step  状态空间模型(SSM)步骤,选择性状态更新
        if selective_state_update is None:
            # Discretize A and B      将 A 和 B 离散化
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))  
            # 对 dt 进行 softplus 激活。
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)  # 通过线性层 out_proj 生成最终的输出 y。
        return out.unsqueeze(1), conv_state, ssm_state
# out.unsqueeze(1):将输出 out 增加一个维度,返回形状为 (B, 1, D) 的张量,这样输出符合与输入的时间步数相同的格式。
# 返回卷积状态 conv_state 和状态空间模型状态 ssm_state,以便在下一步中继续使用。

5 forward 函数

处理输入的 hidden_states,并返回相同形状的输出。这段代码的核心是:处理输入数据、应用变换、卷积计算、状态更新以及最终输出

def forward(self, hidden_states, inference_params=None): # 主体部分
       """
        主要思路:
        1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
        2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
        3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
        4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。
        5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。

        hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 D
        Returns: 返回与输入相同形状的输出
        """
        batch, seqlen, dim = hidden_states.shape
# hidden_states 是模型当前层的输入,通常来自前一层的输出。
# inference_params: 这个参数是可选的,通常用于推理阶段,用来传递状态缓存或者推理时的其他参数。如果没有推理需求,可以为 None

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            # 从缓存中加载卷积状态和状态空间模型状态。
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out
# 如果序列的偏移量大于零(表示模型在推理中已经处理了一些序列,可能是自回归解码),则跳过卷积计算,直接进入 step 函数处理当前输入,并返回输出。

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
# 使用 rearrange(hidden_states, "b l d -> d (b l)") 把 hidden_states 重新排列成形状 (D, B*L),这样可以方便地进行矩阵乘法
            "d (b l) -> b d l",
# self.in_proj.weight 是一个线性变换矩阵,将输入数据投影到新的空间,并通过矩阵乘法得到 xz,最终重排为 (B, D, L) 的形状。
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # 通过取指数得到矩阵 A。A 的形状是 (d_inner, d_state)。这个矩阵用于状态空间模型的计算
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  
            # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

5.1 模块内部参数与层定义

  1. 线性变换层 (in_proj, out_proj)

    • in_proj:将输入 (B, L, D) 投影到 (B, L, d_inner * 2)。这相当于将输入分为两个部分 xz,后续分别进行处理。
    • out_proj:将内部维度 (B, L, d_inner) 投影回 (B, L, D)
  2. 卷积层 conv1d
    使用 nn.Conv1dx 部分进行深度卷积 depthwise convolution。因为 groups 等于 d_inner,所以每个通道独立进行卷积,不同通道之间没有混合。这层可以捕捉局部的短程依赖。

  3. 激活函数 SiLU
    使用 SiLU 作为激活函数,为网络增加非线性能力。

  4. SSM 参数 (A, B, C, D) 与 dt

    • A_log:存储 A 矩阵中特征值的对数。当最终求 A 时使用 A = -exp(A_log),确保 A 的特征值为负数,从而保证系统稳定性。

    • D:一个可学习的向量,对输出进行额外的线性修正,有点类似于跳跃连接(skip connection)。

    • x_proj:将经过卷积和激活后的特征 x 投影到 (dt_rank + 2 * d_state) 的维度,用于得到 dtBC 参数。

    • dt_proj:将 dt 子空间映射到 d_inner 维度,并为 dt 参数添加偏置项。通过 softplus 确保 dt 始终为正值。

    • dt:控制连续时间状态转移到离散时间的转换长度。

    • BC:从输入特征中动态计算,控制状态更新方程 s(t+1) = s(t)*exp(A*dt) + x(t)*... 和输出方程 y(t) = s(t)*C + x(t)*D

5.2 前向传播 forward

forward(self, hidden_states, inference_params=None) 的功能是对整个序列进行处理或在推理模式下逐步处理。

  • inference_params is None,表明在训练或批量推理时并行处理整个序列。
  • inference_params 提供,则可能在自回归场景下每次输入一个时间步,并利用缓存的状态 (conv_state, ssm_state) 来加速计算。

下面展示前向传播基本流程:

  1. 输入投影
    in_proj(B, L, D) 投影到 (B, L, d_inner * 2),分成 xz 两部分。

  2. 计算 A 矩阵
    使用 A = -exp(A_log) 得到状态转移矩阵的隐式表示。

  3. 卷积与激活
    x 输入到卷积层中捕捉局部信息,随后通过 SiLU 激活函数。

  4. 计算 dt, B, C
    利用 x_proj(B, L, d_inner) 映射到 (B, L, dt_rank + 2*d_state),再分解为 dt, B, C

    • dt 通过 dt_proj 映射并加上偏置,经 softplus 后得到最终离散化时间步。
    • B, C 则直接用于状态更新方程。
  5. 状态空间扫描 (SSM Scan)
    使用 selective_scan_fn(若可用的话)或其他计算方式,将 (x, dt, A, B, C, D, z) 输入状态模型进行迭代更新。

    • 状态更新方程:s(t+1) = s(t)*exp(A*dt) + x(t)*B*dt
    • 输出方程:y(t) = s(t)*C + D*x(t)
    • 最终通过 y 和 z 的激活组合(如 y * act(z)) 得出中间结果。
  6. 输出投影
    最后通过 out_proj 将中间结果 (B, L, d_inner) 映射回 (B, L, D),得到最终输出。

5.3 单步推理 step

step(self, hidden_states, conv_state, ssm_state) 用于自回归场景下一步一步处理输入(例如语言模型在推断阶段每次输入一个 token)。

  • 接受当前时间步的输入 (B, 1, D)

  • 使用 in_proj 分为 xz

  • 更新卷积状态 conv_state,执行单步卷积输出。

  • 根据 x 计算出 dt, B, C 等参数。

  • 使用前一次的 ssm_state 更新到下一个状态。

  • 计算当前时间步的输出,并返回 (out, conv_state, ssm_state)

5.4 状态缓存

  • allocate_inference_cache:在推理模式下为 (conv_state, ssm_state) 分配内存。
  • _get_states_from_cache:根据 inference_params 获取或初始化该层对应的状态存储,用于分步推理时的缓存与复用。

5.4 自定义 CUDA 内核

代码中尝试导入 causal_conv1d_fnselective_scan_fnselective_state_update 等自定义高效内核。当这些内核可用时,代码可使用更高效的路径加速卷积和状态更新的计算。当不可用时,则退回使用 PyTorch 原生操作。

若有想交流的小伙伴可以私信我,或联系微信:a2744739916,欢迎大家!


Mamba 源码处:

class Mamba(nn.Module):
   # 这个实例化对象继承自 nn.Module,将构成整个Mamba模型的结构和前向传播逻辑
    def __init__(
        self,
        d_model,  # 模型隐藏层的维度
        d_state=16,   # 状态空间的维度,默认为16
        d_conv=4,  # 定义卷积核的维度,默认为4
        expand=2,   # 定义扩展因子,默认为2
        dt_rank="auto",   # 定义输入依赖步长Δ的秩,'auto'表示自动设置
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,  # 定义卷积层使用偏置项
        bias=False,  # 定义其他层(如线性层)是否使用偏置项
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)  # 计算内部维度,即扩展后的维度
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank   
        # 如果,dt_rank != 'auto',根据隐藏层维度自动计算Δ的秩
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
		
        # 输入线性变化层
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(  # 将输入x映射到状态空间模型的参数Δ、B和C
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        # 将Δ从dt_rank维度映射到d_inner维度,即输入层到隐藏层
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值