Manba模型跟着源代码来深入理解他的原理

1.1 Transformer模型的优劣势

        transformer自从2017年被提出已经在各个领域大杀四方,transformer深度原理博主将不在详细讲解,我们只是来简单的分析transformer的优点和劣势。

        transformer的源结构如图1所示。整个模型框架中最重要的模块就是自注意力机制,自注意力机制的原理实际上就是分为Q K V,其中V就是Q的答案,往往K和Q是一样的,也可以是不一样的,但是K值需要和V值有相关性。Q K依次计算他们的相关性,形成一个大的二维矩阵,也即Qi Kj 之间的相关性,然后根据计算出来的相关性矩阵和V进行相乘,这样就可以得到说Vj 用 其他的值计算出他们的相关性的值,如Vj = v1 * V1 + v2 * V2 ..... 依次相加得到新的值。通过transformer的自注意力机制的方式,可以使得模型关注更多上下文之间的关联,从而实现在预判下个token时候充分考虑了上下文。虽然 transformer的自注意力方法非常有效,但是他同样也有他的缺点,缺点之一就是他的计算复杂度非常高,为O(N^2*d + N^2) 其中 d表示的是嵌入维度信息,N表示的是上下文长度。这个计算复杂度是非常之高的,虽然在训练的过程中,可以通过并行计算来加速训练的过程,但是在推理阶段却非常之慢。那有没有一种更加合理的模型来解决这些问题呢?科学家们也一直在寻找和解决transformer的问题,从而提升模型的能力。很庆幸,在2023的时候,科学家们提出了一种新型模型,该模型叫manba模型。manba模型不仅拥有比transformer更快的训练速度,同时他在推理速度上面也非常之快,并且他独特的模型结果,使得他在更长的上下文中具有更强的性能。


图1. transformer原型结构图

1.2 SSM模型,状态转移模型

        在讲Manba模型,还是需要先从SSM模型开始进行讲解。已经很很多博主都是从RNN开始讲Manba模型,但是本次我将跳过RNN模型直接从SSM模型开始讲解Manba模型,因为我觉得SSM模型是Manba模型的核心,也是直接相关模型。所以避免读者刚开始学,类似于博主这种才刚学的人来说,其实越少讲解和越直击核心的文章可能更加容易理解这些高深的模型和原理。SSM模型又称为结构化状态空间序列模型,其基本的公式核心如下所示:

        A表示的是状态转移矩阵,这里我大致讲解一下A矩阵。因为本人对于状态矩阵也不是非常理解,毕竟只是一个初学者。我们知道时序数据可以理解为一个非常高维空间的一个曲线函数。我们首先考虑一维,因为更多高维的数据其实也只是一维的拓展。我们知道傅里叶展开可以对任意的非周期性曲线进行展开。然后我们就可以得到一个大概的公式,公式如下: f(x) = sum(cn * e(i2pi * n)), 如果我们把e(i2pi * n) 看成是一个基座,那我们就得到Cn就是这个基座的值。如我们在二维坐标系中(1,2) 表示 1个x, 2个y一样。f(x)相当于变成了一个非常高维度的曲线,且他们的基座就是e(i 2 pi n), 如果我们h(t)的每一个维度都变成傅里叶展开,那么A 实际上就是函数逼近算法中的各个基座的值,如以傅里叶展开的时候,就是傅里叶级数。那么A就是对时序数据信的表达系数,也可以理解为对他在高维空间的压缩。

        因为Dx(t) 相当于一个残差链接,所以我们可以先不关注这部分数据,那么公式就变成了如下形式:

         因为上述的表达式是属于连续状态转移方程,那么我们将他转变成离散状态转移方程,其中变换公式如下图所示,所有的公式为手写推导,故全部写在纸上。

         可以看出,经过公式推导以后,我们得到离散化后的A = e(delta * A), B经过如下图所示的公式推导,零阶保持法也即U(t) 在 t -> t+1阶段保持不变,所以就变成了如下公式推导过程。

        至此我们就得到离散化以后的状态转移方程矩阵关系了。

        假设如果A B C都不变,经过其他很多博主的分享,实际上他就等同于一维卷积的过程,然后在采用并发卷积的过程就可以实现很快的分享了。但是如果A B C矩阵都不变的情况下,那么此时输入的 X 经过B C以后就不能很好的得到重点关注是什么了,也叫线性时不变系统。这里可能会比较难以理解,我也是理解了很久才理解这个概念,为什么A B C 不变就得不到重点关注呢?我们假设我们当前输入的系统值是x(t),因为B对于当前时刻的值相乘是固定的,所以对于x(t)不管是什么,此时B对于他的通过或者不通过都是一视同仁,或者如果他是一个倍乘关系则这个t时刻的x 加入到状态转移矩阵里面的值都是固定不变的,也就没有所谓的上下文关系或者重点捕捉的能力了。那么Manba模型为了解决这个问题,又引入了线性变化系统,也就是 B ,C, delta 都是和输入相关的,这样就可以起到针对不同的输入,他可以做到对指定t时刻对应的x值到底做过滤或者重点关注等等操作了。 

 2.1 跟着源码看Manba模型的原理

 

图5 manba模型图 

 

图6 manba block 

def create_block(
    d_model,    d_intermediate,    ssm_cfg=None,    attn_layer_idx=None,    attn_cfg=None,    norm_epsilon=1e-5,    rms_norm=False,    residual_in_fp32=False,    fused_add_norm=False,    layer_idx=None,    device=None,    dtype=None,):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,            layer_idx=layer_idx,            **ssm_cfg,            **factory_kwargs
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )
    block = Block(
        d_model,        mixer_cls,        mlp_cls,        norm_cls=norm_cls,        fused_add_norm=fused_add_norm,        residual_in_fp32=residual_in_fp32,    )
    block.layer_idx = layer_idx
    return block

 

class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False    ):
        """        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"        This Block has a slightly different structure compared to a regular        prenorm Transformer block.        The standard block is: LN -> MHA/MLP -> Add.        [Ref: https://arxiv.org/abs/2002.04745]        Here we have: Add -> LN -> Mixer, returning both        the hidden_states (output of the mixer) and the residual.        This is purely for performance reasons, as we can fuse add and LayerNorm.        The residual needs to be provided (except for the very first block).        
        """ 
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
     def forward(
            self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
        ):
        r
        """Pass the input through the encoder layer.        Args:            hidden_states: the sequence to the encoder layer (required).            residual: hidden_states = Mixer(LN(residual))
         """       
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
               is_rms_norm=isinstance(self.norm, RMSNorm)
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                   self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                   prenorm=True,
                   residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                   is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
def forward(self, hidden_states, inference_params=None):
    """    
    hidden_states: (B, L, D)    Returns: same shape as hidden_states
    """ 
    batch, seqlen, dim = hidden_states.shape

    conv_state, ssm_state = None, None
    if inference_params is not None:
        conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
        if inference_params.seqlen_offset > 0:
            # The states are updated inplace
            out, _, _ = self.step(hidden_states, conv_state, ssm_state)
            return out

    # We do matmul and transpose BLH -> HBL at the same time
    xz = rearrange(
        self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),        "d (b l) -> b d l",        l=seqlen,    )
    if self.in_proj.bias is not None:
        xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
     
    # 查看离散部分的推导,可以看出A就被转换成了一个Exp(logA),为什么用logA,其实我也不知道,有大神知道的可以告知一下,谢谢
    A = -torch.exp(self.A_log.float())# (d_inner, d_state)
   # In the backward pass we write dx and dz next to each other to avoid torch.cat
   if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:
      # Doesn't support outputting the states
       out = mamba_inner_fn(
            xz,
            self.conv1d.weight,
            self.conv1d.bias,
            self.x_proj.weight,
            self.dt_proj.weight,
            self.out_proj.weight,
            self.out_proj.bias,
            A,
            None,# input-dependent B
            None,  # input-dependentC
            self.D.float(),
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
      )
    else:
        x, z = xz.chunk(2, dim=1)
        # Compute short convolution
        if conv_state is not None:
            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
        #对数据进行一次卷积
        if causal_conv1d_fn is None:
            x = self.act(self.conv1d(x)[..., :seqlen])
        else:
            assert self.activation in ["silu", "swish"]
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation=self.activation,
                )

        # We're careful here about the layout, to avoid extra transposes.        # We want dt to have d as the slowest moving dimension        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        # 将X矩阵经过线性变换成dt_rank, d_state 维度的数据,其中d_state表示的是隐状态下的维度数据
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = self.dt_proj.weight @ dt.t()
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        assert self.activation in ["silu", "swish"]
        y = selective_scan_fn(
            x,            dt,            A,            B,            C,            self.D.float(),            z=z,            delta_bias=self.dt_proj.bias.float(),            delta_softplus=True,            return_last_state=ssm_state is not None, 
        )
        if ssm_state is not None:
            y, last_state = y
            ssm_state.copy_(last_state)
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
    return out

        结合上面的代码,以及下面的Mamba模型图,我们可以看出 x(t)时刻的数据经过 project 线性变换以后,转换成了 Bt , deltaT, Ct 。因此B delta C 就变成了和输入强相关的变量,而其中project是不变的。

        然后状态转移矩阵通过delta t 来改变A的时变系统,这样就可以变成了整个系统就是一个时变系统。为什么A要不变呢?我个人的理解是说A相当于整个函数的不变的部分,因为状态或者我们自己人也一样在记忆的时候他的以及的逻辑和惯性应该是不变的,变得应该是指当前时刻记忆或者遗忘得范围,比如我们今天多学习了昨天得信息,那我们就更多得记住了昨天得,可能对以前得就加大了遗忘场地程度。接下来我们来看最重要的SSM系统,也即selective_scan_ref函数。 

def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,                      return_last_state=False):
    """u: r(B D L)delta: r(B D L)A: c(D N) or r(D N)B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)D: r(D)z: r(B D L)delta_bias: r(D), fp32
out: r(B D L)last_state (optional): r(B D dstate) or c(B D dstate)
"""
    dtype_in = u.dtype
    u = u.float()
    # 将 delat变成更平滑
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3is_variable_C = C.dim() >= 3if A.is_complex():
    if is_variable_B:
        B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
    if is_variable_C:
        C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    # t时刻的delta 和 A进行相乘
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            # 每个 t 时刻的向量都和B对应t时刻的向量 和 delta t 时刻的向量进行相乘
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    #下面就是对
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2    ys.append(y)
    y = torch.stack(ys, dim=2) # (batch dim L)out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)

        查看上面的代码在结合Mamba的原型图,我们B Delta C都是X(T)时刻相关的,也就是对应的语言模型中的第 t 个的dim向量经过project 得到的B,Delta, C。实际上Delta 就相当于门控,也就是因为他控制着A的大小,控制着X的输入情况。因为A是-exp(logA) 所以,delta 越大,意味这对影状态 遗忘的越多,也就是对上下文进行重置,同时对X当前输入的信息关注度越高。如果delta越小,意味着对保持当前的上下文,对x 的输入关注度较小。这样通过 delta的控制就可以很好的控制系统要关注上下文还是关注当前的信息。而B 控制着 x是否要输入到状态空间中,因为B = WX,对应的可能就是通过上下文的关系决定当前的信息是否进入系统中,训练到最后可能B对应的 t 时刻的值就接近于 0 或者 1?C 是控制这输出,是否输出到Y信息中。这样通过B C Delta 和Xt 时刻相关,这样就可以使得系统在不同时刻是否关注上下文和是否只关注当前信息。从而实现了更高的上下文。

        通过源码可以发现 B 不在是 离散模型中的

 

        B = dela * B 也就是省略了前面的值。为什么呢?我们仔细观察函数

        我们假设exp(delta A) = delta A + I , 因为 A是小于0的,所以 exp(delta * A) -> (0,1) 之间变化,设 X = delta * A 则 exp(x) = x + 1 当 x 在 -1 - 0 之间变化的时候,则两边是差不多相等的。只要我们控制好 delta * A 在 -1 - 0之间即可让 B = delta * B即可。

        至此我们已经将Mamba模型讲了一遍,Mamba没有采用S4模型中的将A 初始化为Hippo矩阵,而是初始化0然后进行训练,这里我不太理解为什么不采用S4模型中的hippo初始化。

我们至此仍然有两个疑问,一个是为什么要对A取对数,2 为什么不采用Hippo初始化A矩阵。对A取对数,我可能的理解就是为了让遗忘更慢一些?亦或者让delta * A 更加等于 0 - 1之间?

  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值