(简单易学)mamba2核心ssd算法逻辑整理(基于mamba2-minimal实现)

本文最后也有部分代码,更详细的请参考官方链接https://github.com/tommyip/mamba2-minimal/blob/main/mamba2.py这段代码实现的是 Mamba-2 模型的核心算法 Structured State Space Duality (SSD),用于高效地处理序列数据。让我们来逐步拆解这段代码,理解其背后的原理和运作机制。

1. SSD 算法概述

  • SSD 是一种基于状态空间模型的序列处理方法,其核心思想是将序列分解成若干个块 (chunk),并在块内和块间进行高效的信息传递。

  • SSD 利用了矩阵的低秩分解和指数衰减特性,将原本复杂的序列建模问题转化为一系列高效的矩阵乘法运算,从而显著降低了计算复杂度。

2. 代码解析

segsum 函数解析:

  • 功能: 计算一个特殊的累积和,用于生成一个 1-半可分矩阵(1-semiseparable matrix),该矩阵等效于一个标量状态空间模型 (SSM)。

  • 输入:

    • x: 输入张量。

    • device: 计算设备。

  • 步骤:

    1. 复制扩展: 将输入 x 沿着最后一个维度复制 T 次,生成一个新的张量。

    2. 生成掩码: 创建两个下三角掩码矩阵,分别用于控制累积和的范围。

    3. 计算累积和: 使用 torch.cumsum 函数计算累积和,并利用掩码矩阵控制计算范围。

    4. 填充负无穷: 将不在计算范围内的元素填充为负无穷,这是为了在后续计算指数时将其置零。

SSD 函数解析:

  • 输入:

    • x: 输入序列,形状为 (batch, seqlen, n_heads, d_head)。

    • A: 控制状态转移的矩阵,形状为 (batch, seqlen, n_heads)。

    • B: 将输入映射到状态空间的矩阵,形状为 (batch, seqlen, n_heads, d_state)。

    • C: 将状态空间映射到输出的矩阵,形状为 (batch, seqlen, n_heads, d_state)。

    • chunk_size: 将序列分割成的块的大小。

    • initial_states: 初始状态,可选。

    • device: 计算设备,例如 CPU 或 GPU。

  • 步骤:

    1. 数据重排 (Rearrange into chunks):

      • 将输入序列、A、B、C 矩阵按照 chunk_size 分割成若干个块。

      • 代码中使用了 rearrange 函数进行高效的数据重排操作。

    2. 块内计算 (Intra-chunk computation):

      • 计算每个块内的输出 Y_diag:

        • L = torch.exp(segsum(A, device=device)) 计算状态转移矩阵的累积效应。

        • Y_diag 通过 C, B, L 和 x 的矩阵乘法得到。

      • 计算每个块内的最终状态 states:

        • decay_states 计算状态的衰减情况。

        • states 通过 B, decay_states 和 x 的矩阵乘法得到。

    3. 块间循环 (Inter-chunk recurrence):

      • 使用 initial_states 初始化状态,或使用默认的零向量。

      • decay_chunk 计算块间的衰减情况。

      • new_states 通过 decay_chunk 和 states 的矩阵乘法得到,表示更新后的状态。

      • final_state 保存最后一个块的最终状态。

    4. 状态到输出的转换 (State-to-output conversion):

      • state_decay_out 计算状态衰减对输出的影响。

      • Y_off 通过 C, states 和 state_decay_out 的矩阵乘法得到,表示块间信息传递对输出的贡献。

    5. 输出合并 (Output combination):

      • 将块内输出 Y_diag 和块间输出 Y_off 相加,得到最终的输出 Y。

  • 输出:

    • Y: 处理后的输出序列,形状为 (batch, seqlen, n_heads, d_head)。

    • final_state: 最后一个块的最终状态。

3. 代码亮点:

  • 高效的矩阵运算: 代码大量使用了 torch.einsum 函数,这是一种高效的矩阵乘法运算方法,可以充分利用硬件资源,加速计算。

  • 并行计算: 代码中提到,步骤 1、2 和 4 可以并行计算,这为进一步提升性能提供了空间。

总结:

  • segsum 函数是一个辅助函数,用于生成 SSD 算法中所需的特殊矩阵。

  • ssd 函数实现了 SSD 算法,通过将序列分解成块并利用矩阵的低秩分解和指数衰减特性,实现了高效的序列建模。

def segsum(x: Tensor, device: Device = None) -> Tensor:
    """Stable segment sum calculation.

    `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.

    Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
    """
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
    """Structed State Space Duality (SSD) - the core of Mamba-2

    This is almost the exact same minimal SSD code from the blog post.

    Arguments
        x: (batch, seqlen, n_heads, d_head)
        A: (batch, seqlen, n_heads)
        B: (batch, seqlen, n_heads, d_state)
        C: (batch, seqlen, n_heads, d_state)

    Return
        y: (batch, seqlen, n_heads, d_head)

    Source
     1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
     2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
    """
    assert x.shape[1] % chunk_size == 0

    # Rearrange into chunks
    # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
    # This is not implemented and left as an exercise for the reader 😜
    x, A, B, C = [
        rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
    ]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A, device=device))
    Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
    new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")

    return Y, final_state

最后附上关于mamba2块缝合到你的模型中的缝合方法链接

(简单易学)将mamba2添加到你的模型(NLP | CV-2d)中【PyTorch】_mamba2安装-CSDN博客

  • 28
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值