VMamba: Implementation of SS2D

The core of the VSSM is the VSSBlock, and the core of the VSSBlock is SS2D, so this post focuses on the SS2D block. It is only a quick walkthrough of the code, not an explanation of the underlying theory. The rotation and flip operations applied to the variables, as well as the extra K dimension compared with a standard SSM, should both be related to the cross-scan mechanism; my understanding is limited, so corrections are welcome.



The Mamba implementation itself can be found in my earlier mamba_minimal series.
Paper: VMamba
Paper-reading notes: VMamba: Visual State Space Model
Code: https://github.com/MzeroMiko/VMamba.git

| Parameter and shorthand | Abbreviation in the VMamba paper |
|---|---|
| batch_size b | B |
| number of image pixels h*w / l | L |
| hidden dimension d / d_model | |
| latent state dimension n / d_state | N |
| expansion factor ssm_ratio | |
| d_in / d_inner | D |
| data-dependent step size Δ / dts | |
| rank of delta, dt_rank | |

In what follows, tensor sizes are written simply as [b, d, h, w] or [b, 4, d_state, h*w].

class SS2D

Overall structure

The overall structure of the SS2D block, traced through its forward function.

def forward

Variable shapes:

| Variable | Shape |
|---|---|
| input x | (b, h, w, d_model) |
| x after the input projection | (b, h, w, 2*d) |
| x after the split | (b, h, w, d) |
| z | (b, h, w, d) |
| y | (b, h, w, d) |
| out | (b, h, w, d_model) |

The overall structure of the SS2D block is similar to the MambaBlock in mamba_minimal, but the SSM part is more complex and the 1D convolution becomes a 2D convolution. The variable z here should be a gating variable; def forward_core contains the SSM part.


Operation-by-operation shape changes:

| Operation | Shape change |
|---|---|
| in_proj | (b, h, w, d_model) -> (b, h, w, 2*d) |
| permute + conv2d | (b, h, w, d) -> (b, d, h, w) |
| ssm | (b, d, h, w) -> (b, h, w, d) |
| out_proj | (b, h, w, d) -> (b, h, w, d_model) |

Before the convolution, the tensor is transposed so that d comes before the spatial dimensions h and w; this is what channel_first means, so once the convolution has been applied, channel_first is True.

    def forward(self, x: torch.Tensor, **kwargs):
        with_dconv = (self.d_conv > 1)
        x = self.in_proj(x)
        if not self.disable_z:
            x, z = x.chunk(2, dim=-1) # (b, h, w, d)
            if not self.disable_z_act:
                z = self.act(z)
        if with_dconv:
            x = x.permute(0, 3, 1, 2).contiguous()
            x = self.conv2d(x) # (b, d, h, w)
        x = self.act(x)
        y = self.forward_core(x, channel_first=with_dconv)
        if not self.disable_z:
            y = y * z
        out = self.dropout(self.out_proj(y))
        return out
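The block therefore keeps a channel-last layout at its boundary: (b, h, w, d_model) in and out. Below is a minimal usage sketch; it assumes SS2D is imported from the VMamba repo (the module path is hypothetical), that the repo's CUDA selective-scan kernels are compiled, and that a GPU is available.

import torch
# Assumptions: SS2D comes from the VMamba repo (the module path below is hypothetical),
# the repo's CUDA selective-scan kernels are compiled, and a GPU is available.
from models.vmamba import SS2D  # hypothetical import path

ss2d = SS2D(d_model=96, d_state=16, ssm_ratio=2.0, d_conv=3).cuda()
x = torch.randn(2, 14, 14, 96, device="cuda")   # (b, h, w, d_model), channel-last input
y = ss2d(x)                                     # in_proj -> conv2d -> SSM -> gating -> out_proj
print(y.shape)                                  # torch.Size([2, 14, 14, 96])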

Initialization

def __init__

General initialization

The parameters that need to be defined at initialization are listed below.

Basic dimension parameters:

| Parameter | Description |
|---|---|
| d_model | input dimension |
| d_state | hidden-state dimension |
| ssm_ratio | d_inner = d_model * ssm_ratio |
| dt_rank | rank of the delta step |

Module and behavior parameters:

| Parameter | Description |
|---|---|
| d_conv | 2D convolution kernel size |
| conv_bias | bias of the convolution |
| forward_type | forward version |
| initialize | initialization version |

Other parameters:

| Parameter | Description |
|---|---|
| dropout | dropout rate |
| bias | bias of the linear projections |
| act_layer | SiLU |

delta initialization parameters:

| Parameter | Description |
|---|---|
| dt_min | controls the value range |
| dt_max | controls the value range |
| dt_init | how the values are initialized, e.g. random |
| dt_scale | defines the std used for initialization |
| dt_init_floor | controls numerical stability |

For space reasons, not all of the code is shown here; specifically, the code handling the forward-type tags is omitted. What remains is fairly clear: first d_inner is computed from ssm_ratio, then dt_rank is assigned (for example, with d_model = 96 and ssm_ratio = 2.0, d_inner = 192 and dt_rank = ceil(96/16) = 6). The operations defined afterwards are:

| Operation | Role | Shape |
|---|---|---|
| in_proj | input projection | (d_model, 2*d_inner) |
| act | activation layer | |
| conv2d | convolution | |
| x_proj | projections producing the data-dependent parameters | List[(d_inner, dt_rank + d_state*2)] * 4 |
| x_proj_weight | the stacked x_proj weights kept as a parameter | (4, d_state*2 + dt_rank, d_inner) |
| out_proj | output projection | (d_inner, d_model) |
| dropout | | |
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_inner = int(ssm_ratio * d_model)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        self.d_conv = d_conv
        
        
        # in proj =======================================
        d_proj = d_inner if self.disable_z else (d_inner * 2)
        self.in_proj = nn.Linear(d_model, d_proj, bias=bias, **factory_kwargs)
        self.act: nn.Module = act_layer()
         # conv =======================================
        if d_conv > 1:
            self.conv2d = nn.Conv2d(
                in_channels=d_inner,
                out_channels=d_inner,
                groups=d_inner,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                **factory_kwargs,
            )
        # x proj ============================
        self.x_proj = [
            nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs)
            for _ in range(k_group)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
        del self.x_proj
            
         # out proj =======================================
        self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
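The x_proj trick above, stacking the four Linear weights into a single parameter and then deleting the Linear modules, can be reproduced in isolation. A minimal sketch with illustrative sizes (d_inner = 192, d_state = 16, dt_rank = 6, k_group = 4):

import torch
import torch.nn as nn

d_inner, d_state, dt_rank, k_group = 192, 16, 6, 4   # illustrative sizes
x_proj = [nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False) for _ in range(k_group)]
x_proj_weight = nn.Parameter(torch.stack([t.weight for t in x_proj], dim=0))
print(x_proj_weight.shape)   # torch.Size([4, 38, 192]) = (K, dt_rank + 2*d_state, d_inner)

# the four projections are then applied in one einsum instead of four separate Linear calls
xs = torch.randn(2, k_group, d_inner, 49)                     # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
print(x_dbl.shape)           # torch.Size([2, 4, 38, 49])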

Model parameter initialization

Three versions are provided: v0, v1 and v2. v0 is used as the example here.

In v0, the first operations concern the delta parameter, i.e. dt.

        if initialize in ["v0"]:
            # dt proj ============================
            self.dt_projs = [
                self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
                for _ in range(k_group)
            ]
            self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank)
            self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner)
            del self.dt_projs
            
            # A, D =======================================
            self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N)
            self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D)

def dt_init

Initialization of the dt parameters, plus the definition of the delta projection. The delta parameter is produced by: x -> x_proj -> split -> dt_proj -> delta.

The input x is projected by x_proj to obtain the three data-dependent parameters $B$, $C$, $\Delta$; the $\Delta$ obtained this way has dimension dt_rank, so it still needs a (dt_rank, d_inner) linear projection.
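A minimal shape sketch of this pipeline, using illustrative sizes (b = 2, k = 4, d_inner = 192, d_state = 16, dt_rank = 6, l = 49):

import torch

b, k, d_inner, d_state, dt_rank, l = 2, 4, 192, 16, 6, 49      # illustrative sizes
x_dbl = torch.randn(b, k, dt_rank + 2 * d_state, l)            # output of the x_proj einsum
dts, Bs, Cs = torch.split(x_dbl, [dt_rank, d_state, d_state], dim=2)
print(dts.shape, Bs.shape, Cs.shape)    # (2, 4, 6, 49) (2, 4, 16, 49) (2, 4, 16, 49)

# dts is still in the low-rank dt_rank space; dt_proj lifts it to d_inner,
# written here as the stacked-weight einsum used by forward_corev0
dt_projs_weight = torch.randn(k, d_inner, dt_rank)             # (K, inner, rank)
delta = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)
print(delta.shape)                      # torch.Size([2, 4, 192, 49])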

| Parameter | Description |
|---|---|
| dt_rank | rank of the delta parameter |
| d_inner | |
| dt_scale | defines the std used for initialization |
| dt_init | initialization method of dt_proj |
| dt_min | controls the value range |
| dt_max | controls the value range |

$dt = e^{\alpha \, (\log(dt\_max) - \log(dt\_min)) + \log(dt\_min)}$

where $\alpha$ is drawn from a uniform distribution on [0, 1], so $dt$ ranges from $e^{\log(dt\_min)}$ to $e^{\log(dt\_max)}$, i.e. from $dt\_min$ to $dt\_max$.

The softplus function is $\mathrm{Softplus}(x) = \frac{1}{\beta}\log(1 + \exp(\beta x))$; the bias of dt_proj is set to the inverse softplus of dt, so that F.softplus(dt_proj.bias) falls back into the [dt_min, dt_max] range.
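A quick numeric sketch of this bias initialization: sample dt uniformly in log-space between dt_min and dt_max, apply the inverse of softplus, and check that softplus maps the bias back to dt (the values below are the defaults of dt_init):

import math
import torch
import torch.nn.functional as F

dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4               # defaults of dt_init below
alpha = torch.rand(5)                                          # uniform in [0, 1)
dt = torch.exp(alpha * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
dt = dt.clamp(min=dt_init_floor)                               # values lie in [dt_min, dt_max]
inv_dt = dt + torch.log(-torch.expm1(-dt))                     # inverse of softplus
print(torch.allclose(F.softplus(inv_dt), dt, atol=1e-6))       # True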

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        # dt_proj.bias._no_reinit = True
        
        return dt_proj

def A_log_init

Initialization of the state matrix A (stored as A_log).

| Parameter | Description |
|---|---|
| copies | equals k = 4 under initialization version v0 |
| merge | True under initialization version v0 |

Without copies, A_log has shape $(d\_inner, d\_state)$; after copying it becomes $(4, d\_inner, d\_state)$.

After the merge, A_log has shape $(4 \cdot d\_inner, d\_state)$.

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 0:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

def D_init

Initialization of the D ("skip") matrix.

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 0:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D
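A short shape check for the two initializers above, with illustrative sizes d_state = 16, d_inner = 192 and copies = k_group = 4; both are called as static methods on SS2D, matching how __init__ uses them:

# assumes the SS2D class from the repo (with the static methods shown above) is in scope
A_logs = SS2D.A_log_init(d_state=16, d_inner=192, copies=4, merge=True)
Ds = SS2D.D_init(192, copies=4, merge=True)
print(A_logs.shape)   # torch.Size([768, 16])  i.e. (4*d_inner, d_state)
print(Ds.shape)       # torch.Size([768])      i.e. (4*d_inner,)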

Forward pass

There are two forward-pass modes, i.e. two versions of forward_core: v0 and v2.

For the main forward function, see the Overall structure section above.

def forward_corev0

Dimension symbols:

| Symbol | Meaning |
|---|---|
| B | batch_size |
| D | d_inner |
| H | image height |
| W | image width |
| N | d_state |
| K | 4 here (the four scan directions) |
| R | dt_rank |
| L | H * W |

Intermediate variable shapes:

| Variable | Shape | Notes |
|---|---|---|
| input x | [b, d, h, w] | |
| x_hwwh | [b, 2, d, l] | see below |
| xs | [b, 4, d, l] | |
| x_dbl | [b, 4, dt_rank + 2*d_state, l] | split to obtain the data-dependent variables |
| dts | [b, 4, dt_rank, h*w] | delta matrix (before the dt_proj projection) |
| Bs | [b, 4, d_state, h*w] | B matrix |
| Cs | [b, 4, d_state, h*w] | C matrix |
| As | [4*d, d_state] | A matrix |
| Ds | [4*d] | D matrix |

x_hwwh

The shape manipulations here are a little involved; step by step:

x.view(B, -1, L): [b, d, h*w], the row-major (h, w) scan of the image

torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L): [b, d, w*h], the column-major (w, h) scan

torch.stack([...], dim=1): stacking the two along a new dim=1 gives [b, 2, d, l], where l = h*w = w*h

.view(B, 2, -1, L): the shape stays [b, 2, d, l]; the view only makes it explicit

From these steps we can already see what this variable represents: x_hwwh stacks, along dim=1, the image scanned in (h, w) order and the same image with height and width transposed, i.e. scanned in (w, h) order.
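A toy example makes this concrete: a 1x1x2x3 "image" with pixel values 0-5, so the two scan orders are easy to read off.

import torch

x = torch.arange(6, dtype=torch.float32).view(1, 1, 2, 3)      # (b=1, d=1, h=2, w=3)
B, D, H, W = x.shape
L = H * W
x_hwwh = torch.stack(
    [x.view(B, -1, L), torch.transpose(x, 2, 3).contiguous().view(B, -1, L)],
    dim=1,
).view(B, 2, -1, L)
print(x_hwwh[0, 0, 0])   # tensor([0., 1., 2., 3., 4., 5.])  row-major (h, w) order
print(x_hwwh[0, 1, 0])   # tensor([0., 3., 1., 4., 2., 5.])  column-major (w, h) order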

Output variables:

| Variable | Shape |
|---|---|
| out_y | [b, 4, d, l] |
| inv_y | [b, 2, d, l] |
| wh_y | [b, d, h*w] |
| invwh_y | [b, d, h*w] |
| y | [b, d, h*w] -> [b, l, d] |
| output y | [b, h, w, d] |

    def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False):
        def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
            return SelectiveScanCore.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False)

        if not channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        B, D, H, W = x.shape
        D, N = self.A_logs.shape
        K, D, R = self.dt_projs_weight.shape
        L = H * W

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float() # (b, k, d_state, l)
        Cs = Cs.float() # (b, k, d_state, l)
        
        As = -torch.exp(self.A_logs.float()) # (k * d, d_state)
        Ds = self.Ds.float() # (k * d)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        out_y = selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)


        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
        y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
        y = self.out_norm(y).view(B, H, W, -1)

        return (y.to(x.dtype) if to_dtype else y)
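The four-direction merge at the end of forward_corev0 can be sanity-checked in isolation: if the scan were an identity (so out_y were simply the four CrossScan-style sequences), then flipping directions 2-3 back, transposing directions 1 and 3 back, and summing should yield 4 copies of the original image. A minimal sketch under that assumption:

import torch

x = torch.arange(6, dtype=torch.float32).view(1, 1, 2, 3)     # toy (b, d, h, w) input
B, D, H, W = x.shape
L = H * W
# build the four scan sequences exactly as forward_corev0 does
x_hwwh = torch.stack([x.view(B, -1, L),
                      x.transpose(2, 3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
out_y = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # stand-in for the scan output
# merge exactly as forward_corev0 does
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = out_y[:, 1].view(B, -1, W, H).transpose(2, 3).contiguous().view(B, -1, L)
invwh_y = inv_y[:, 1].view(B, -1, W, H).transpose(2, 3).contiguous().view(B, -1, L)
y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
print(torch.equal(y, 4 * x.view(B, -1, L)))                        # True: all pixels re-aligned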

def forward_corev2

The v2 version delegates the main work to the scan function cross_selective_scan.

    def forward_corev2(self, x: torch.Tensor, channel_first=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scan, force_fp32=None):
        if not channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        x = cross_selective_scan(
            x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
            self.A_logs, self.Ds, delta_softplus=True,
            out_norm=getattr(self, "out_norm", None),
            out_norm_shape=getattr(self, "out_norm_shape", "v0"),
            force_fp32=force_fp32,
            SelectiveScan=SelectiveScan,
        )
        return x

Scan

The scan is a central component of the SSM.

Cross scan

def cross_selective_scan

This part is called by forward version v2.

This is the actual scan: the function below calls the selective_scan helper, which applies the SelectiveScanOflex class, and the scan itself is carried out by the imported selective_scan_cuda_oflex module.

The variables and dimensions are defined as in def forward_corev0; ys has shape [b, 4, d, h, w], and after the CrossMerge step the result has shape [b, d, l].

def cross_selective_scan(
    x: torch.Tensor=None, 
    x_proj_weight: torch.Tensor=None,
    x_proj_bias: torch.Tensor=None,
    dt_projs_weight: torch.Tensor=None,
    dt_projs_bias: torch.Tensor=None,
    A_logs: torch.Tensor=None,
    Ds: torch.Tensor=None,
    delta_softplus = True,
    out_norm: torch.nn.Module=None,
    out_norm_shape="v0",
    # ==============================
    to_dtype=True, # True: final out to dtype
    force_fp32=False, # True: input fp32
    # ==============================
    nrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable;
    backnrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable;
    ssoflex=True, # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore
    # ==============================
    SelectiveScan=None,
    CrossScan=CrossScan,
    CrossMerge=CrossMerge,
):
    # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);...

    B, D, H, W = x.shape
    D, N = A_logs.shape
    K, D, R = dt_projs_weight.shape
    L = H * W

    if nrows == 0:
        if D % 4 == 0:
            nrows = 4
        elif D % 3 == 0:
            nrows = 3
        elif D % 2 == 0:
            nrows = 2
        else:
            nrows = 1
        
    if backnrows == 0:
        if D % 4 == 0:
            backnrows = 4
        elif D % 3 == 0:
            backnrows = 3
        elif D % 2 == 0:
            backnrows = 2
        else:
            backnrows = 1

    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)
    
    xs = CrossScan.apply(x)
    
    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
    if x_proj_bias is not None:
        x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)
    xs = xs.view(B, -1, L)
    dts = dts.contiguous().view(B, -1, L)
    As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state)
    Bs = Bs.contiguous()
    Cs = Cs.contiguous()
    Ds = Ds.to(torch.float) # (K * c)
    delta_bias = dt_projs_bias.view(-1).to(torch.float)

    if force_fp32:
        xs = xs.to(torch.float)
        dts = dts.to(torch.float)
        Bs = Bs.to(torch.float)
        Cs = Cs.to(torch.float)

    ys: torch.Tensor = selective_scan(
        xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
    ).view(B, K, -1, H, W)
    
    y: torch.Tensor = CrossMerge.apply(ys)

    if out_norm_shape in ["v1"]: # (B, C, H, W)
        y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1) # (B, H, W, C)
    else: # (B, L, C)
        y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
        y = out_norm(y).view(B, H, W, -1)

    return (y.to(x.dtype) if to_dtype else y)
class CrossScan

A custom autograd function with its own backward pass.

Variable shapes:

| Variable | Shape |
|---|---|
| input x | [b, d, h, w] |
| xs | [b, 4, d, l] |
| xs[:, 0] | [b, d, h*w] |
| xs[:, 1] | [b, d, w*h] |
| xs[:, 2:4] | [b, 2, d, l] |

The backward function receives a gradient of shape (b, 4, d, l), matching xs, and returns a gradient of shape (b, d, h, w), matching the input x.

class CrossScan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 4, C, H * W))
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        return y.view(B, -1, H, W)
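The forward pass can be checked on the same tiny image; the four rows of xs are exactly the four scanning orders (this assumes the CrossScan class above is in scope):

import torch

x = torch.arange(6, dtype=torch.float32).view(1, 1, 2, 3)   # (b, c, h, w)
xs = CrossScan.apply(x)                                      # (b, 4, c, h*w)
print(xs[0, :, 0])
# tensor([[0., 1., 2., 3., 4., 5.],    row-major scan
#         [0., 3., 1., 4., 2., 5.],    column-major scan
#         [5., 4., 3., 2., 1., 0.],    reversed row-major
#         [5., 2., 4., 1., 3., 0.]])   reversed column-major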

Merge

Cross merge

Variable shapes:

| Variable | Shape |
|---|---|
| ys | [b, 4, d, h, w] |
| ys[:, 0:2] | [b, 2, d, l] |
| ys[:, 2:4] | [b, 2, d, l] |
| y | [b, d, l] |

class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return y
    
    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        xs = x.new_empty((B, 4, C, L))
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs = xs.view(B, 4, C, H, W)
        return xs
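Up to the scan itself, CrossMerge undoes CrossScan: merging the four (unscanned) sequences recovers four copies of the flattened input. A minimal round-trip sketch, reshaping xs into the (b, k, d, h, w) layout that CrossMerge expects (assumes both classes above are in scope):

import torch

x = torch.arange(6, dtype=torch.float32).view(1, 1, 2, 3)     # (b, d, h, w)
B, D, H, W = x.shape
xs = CrossScan.apply(x).view(B, 4, D, H, W)                    # (b, k, d, h, w)
y = CrossMerge.apply(xs)                                       # (b, d, h*w)
print(torch.equal(y, 4 * x.view(B, D, -1)))                    # True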