The core of VSSM is the VSSBlock, and the core of the VSSBlock is SS2D, so this post mainly walks through the SS2D block. It is only a quick code walkthrough and does not explain the underlying theory. The rotation/flip operations applied to the variables relate to the cross-scan mechanism, and the extra K dimension compared with a standard SSM should also come from cross-scanning. My understanding is limited, so corrections are welcome.
For the Mamba implementation itself, see the earlier mamba_minimal series.
Paper: VMamba
Paper-reading notes: VMamba: Visual State Space Model
Code: https://github.com/MzeroMiko/VMamba.git
Parameter / shorthand | Notation in the VMamba paper |
---|---|
batch size b | B |
number of image pixels h*w / l | L |
hidden dimension d / d_model | |
latent state dimension n / d_state | N |
expansion factor ssm_ratio | |
d_in / d_inner | D |
input-dependent step size Δ / dts | |
rank of Δ: dt_rank | |
In what follows, tensor sizes are written simply as [b, d, h, w] or [b, 4, d_state, h*w].
class SS2D
Overall structure
The overall structure of the SS2D block is easiest to follow through its forward function.
def forward
Variable | Shape |
---|---|
input x | (b, h, w, d_model) |
x after the input projection | (b, h, w, 2d) |
x after the split | (b, h, w, d) |
z | (b, h, w, d) |
y | (b, h, w, d) |
out | (b, h, w, d_model) |
The overall structure of the SS2D block is similar to the MambaBlock in mamba_minimal, except that the SSM part is more involved and the 1D convolution becomes a 2D convolution. The variable z acts as the gating variable, and def forward_core contains the SSM part.
Operation | Shape change |
---|---|
in_proj | (b, h, w, d_model) -> (b, h, w, 2*d) |
conv2d | permute: (b, h, w, d) -> (b, d, h, w) |
ssm | (b, d, h, w) -> (b, h, w, d) |
out_proj | (b, h, w, d) -> (b, h, w, d_model) |
Before the convolution, x is permuted so that d comes before the spatial dimensions h and w; this is what channel_first means, so if the convolution has been applied, channel_first is True.
def forward(self, x: torch.Tensor, **kwargs):
    with_dconv = (self.d_conv > 1)
    x = self.in_proj(x)
    if not self.disable_z:
        x, z = x.chunk(2, dim=-1)  # (b, h, w, d)
        if not self.disable_z_act:
            z = self.act(z)
    if with_dconv:
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.conv2d(x)  # (b, d, h, w)
    x = self.act(x)
    y = self.forward_core(x, channel_first=with_dconv)
    if not self.disable_z:
        y = y * z
    out = self.dropout(self.out_proj(y))
    return out
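For concreteness, here is a small usage sketch (not from the repo's tests): it assumes the constructor keyword arguments described in the initialization section below, and it needs the repo's CUDA selective-scan kernels to actually run.

import torch

# Hypothetical sizes; d_model / d_state / ssm_ratio / d_conv are the constructor
# parameters listed in the initialization section.
ss2d = SS2D(d_model=96, d_state=16, ssm_ratio=2.0, d_conv=3)
x = torch.randn(2, 56, 56, 96)   # (b, h, w, d_model), channels-last input
y = ss2d(x)                      # (b, h, w, d_model), gated SSM output
print(y.shape)                   # torch.Size([2, 56, 56, 96])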
Initialization
def __init__
Regular initialization
The parameters that need to be defined at initialization are listed below.
Basic dimension parameters | Description |
---|---|
d_model | input dimension |
d_state | latent state dimension |
ssm_ratio | expansion factor: d_inner = d_model * ssm_ratio |
dt_rank | rank of the Δ step size |
Parameter | Description |
---|---|
d_conv | 2D convolution kernel size |
conv_bias | |
forward_type | forward variant |
initialize | initialization variant |
Other parameters | Description |
---|---|
dropout | |
bias | bias |
act_layer | SiLU |
Δ initialization parameters | Description |
---|---|
dt_min | bounds the value range |
dt_max | bounds the value range |
dt_init | initialization scheme, e.g. "random" |
dt_scale | defines the std used for initialization |
dt_init_floor | for numerical stability |
For brevity, not all of the code is reproduced here; specifically, the part that handles the forward_type tags is omitted. What remains is fairly clear: first d_inner is derived from ssm_ratio, then dt_rank is assigned, and the operations defined afterwards are the following.
Operation | Role | Shape |
---|---|---|
in_proj | input projection | (d_model, 2*d_inner) |
act | activation layer | |
conv | 2D convolution | |
x_proj | projection producing the input-dependent parameters | List[(d_inner, dt_rank + d_state*2)] * 4 |
x_proj_weight | the stacked weights kept as a Parameter | (4, dt_rank + d_state*2, d_inner) |
out_proj | output projection | (d_inner, d_model) |
dropout | | |
factory_kwargs = {"device": None, "dtype": None}
super().__init__()
d_inner = int(ssm_ratio * d_model)
dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
self.d_conv = d_conv

# in proj =======================================
d_proj = d_inner if self.disable_z else (d_inner * 2)
self.in_proj = nn.Linear(d_model, d_proj, bias=bias, **factory_kwargs)
self.act: nn.Module = act_layer()

# conv =======================================
if d_conv > 1:
    self.conv2d = nn.Conv2d(
        in_channels=d_inner,
        out_channels=d_inner,
        groups=d_inner,
        bias=conv_bias,
        kernel_size=d_conv,
        padding=(d_conv - 1) // 2,
        **factory_kwargs,
    )

# x proj ============================
self.x_proj = [
    nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs)
    for _ in range(k_group)
]
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
del self.x_proj

# out proj =======================================
self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
Model parameter initialization
Three versions are provided: v0, v1 and v2; v0 is taken as the example here.
The first thing v0 does concerns the delta parameter, i.e. dt.
if initialize in ["v0"]:
    # dt proj ============================
    self.dt_projs = [
        self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
        for _ in range(k_group)
    ]
    self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K, inner, rank)
    self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K, inner)
    del self.dt_projs

    # A, D =======================================
    self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True)  # (K * D, N)
    self.Ds = self.D_init(d_inner, copies=k_group, merge=True)  # (K * D)
def dt_init
Initializes the dt parameters and provides the delta projection. The delta parameter is produced by the pipeline x -> x_proj -> split -> dt_proj -> delta.
The input x is projected by x_proj into the three input-dependent parameters $B$, $C$ and $\Delta$; the $\Delta$ part comes out with dimension dt_rank and still needs a (dt_rank, d_inner) linear projection (dt_proj).
Parameter | Meaning |
---|---|
dt_rank | rank of the Δ parameter |
d_inner | |
dt_scale | defines the std used for initialization |
dt_init | initialization scheme for dt_proj |
dt_min | bounds the value range |
dt_max | bounds the value range |
$dt = e^{\alpha \cdot (\log(dt\_max) - \log(dt\_min)) + \log(dt\_min)}$
where $\alpha$ is drawn from a uniform distribution on $[0, 1]$, so $dt$ ranges from $e^{\log dt\_min}$ to $e^{\log dt\_max}$, i.e. from $dt\_min$ to $dt\_max$.
The softplus function is $\mathrm{Softplus}(x) = \frac{1}{\beta} \log(1 + \exp(\beta x))$.
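A small numeric check of these two formulas (illustrative values only, not repo code): sample dt log-uniformly in [dt_min, dt_max], store its inverse softplus as the bias, and verify that softplus recovers dt.

import math
import torch
import torch.nn.functional as F

dt_min, dt_max, d_inner = 0.001, 0.1, 8
alpha = torch.rand(d_inner)  # alpha ~ U(0, 1)
dt = torch.exp(alpha * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
inv_dt = dt + torch.log(-torch.expm1(-dt))                 # inverse of softplus (beta = 1)
print(dt.min().item(), dt.max().item())                    # both lie in [dt_min, dt_max] up to float error
print(torch.allclose(F.softplus(inv_dt), dt, atol=1e-6))   # True: softplus(bias) gives back dt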
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
    dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

    # Initialize special dt projection to preserve variance at initialization
    dt_init_std = dt_rank**-0.5 * dt_scale
    if dt_init == "constant":
        nn.init.constant_(dt_proj.weight, dt_init_std)
    elif dt_init == "random":
        nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
    else:
        raise NotImplementedError

    # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
    dt = torch.exp(
        torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
        + math.log(dt_min)
    ).clamp(min=dt_init_floor)
    # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        dt_proj.bias.copy_(inv_dt)
    # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
    # dt_proj.bias._no_reinit = True
    return dt_proj
def A_log_init
Initialization of the A matrix (stored in log form as A_log).
Parameter | Description |
---|---|
copies | equal to K = 4 under initialization version v0 |
merge | True under initialization version v0 |
Without copies, A_log has shape $(d\_inner, d\_state)$; after copying it becomes $(4, d\_inner, d\_state)$, and after merge it is flattened to $(4 \cdot d\_inner, d\_state)$.
@staticmethod
def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
    # S4D real initialization
    A = repeat(
        torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
        "n -> d n",
        d=d_inner,
    ).contiguous()
    A_log = torch.log(A)  # Keep A_log in fp32
    if copies > 0:
        A_log = repeat(A_log, "d n -> r d n", r=copies)
        if merge:
            A_log = A_log.flatten(0, 1)
    A_log = nn.Parameter(A_log)
    A_log._no_weight_decay = True
    return A_log
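A quick shape and value check of the function above (illustrative small sizes; it assumes torch is imported and that the staticmethod is reachable as SS2D.A_log_init in the repo's class layout):

A_log = SS2D.A_log_init(d_state=4, d_inner=3, copies=4, merge=True)
print(A_log.shape)             # torch.Size([12, 4]) = (K * d_inner, d_state)
print(-torch.exp(A_log[0]))    # tensor([-1., -2., -3., -4.]): S4D real init, A = -[1, ..., N]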
def D_init
Initialization of the D matrix (the skip-connection parameter).
@staticmethod
def D_init(d_inner, copies=-1, device=None, merge=True):
    # D "skip" parameter
    D = torch.ones(d_inner, device=device)
    if copies > 0:
        D = repeat(D, "n1 -> r n1", r=copies)
        if merge:
            D = D.flatten(0, 1)
    D = nn.Parameter(D)  # Keep in fp32
    D._no_weight_decay = True
    return D
Forward pass
There are two variants of forward_core: v0 and v2.
The main forward function was shown above in the overall-structure section.
def forward_corev0
Dimension symbol | Meaning |
---|---|
B | batch_size |
D | d_inner |
H | image height |
W | image width |
N | d_state |
K | 4 here (four scan directions) |
R | dt_rank |
L | H * W |
Intermediate variable | Shape | Description |
---|---|---|
input x | [b, d, h, w] | |
x_hwwh | [b, 2, d, l] | see below |
xs | [b, 4, d, l] | |
x_dbl | [b, 4, dt_rank + 2*d_state, l] | split to obtain the input-dependent variables |
dts | [b, 4, dt_rank, h*w] | Δ after the split (becomes [b, 4, d, l] after the dt_proj projection) |
Bs | [b, 4, d_state, h*w] | B matrix |
Cs | [b, 4, d_state, h*w] | C matrix |
As | [4*d, d_state] | A matrix |
Ds | [4*d] | D matrix |
x_hwwh
Its shape transformation takes a few steps (the second intermediate is renamed A2 here to avoid clashing with the batch dimension B):
A = x.view(B, -1, L)                                                   # [b, d, h*w]
A2 = torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)    # [b, d, w*h]
torch.stack([A, A2], dim=1)                                            # [b, 2, d, l], l = h*w = w*h (stack inserts a new dim)
torch.stack([A, A2], dim=1).view(B, 2, -1, L)                          # [b, 2, d, l]
From these steps it is clear what the variable represents: x_hwwh stacks, along dim=1, the row-major (h-then-w) flattening of the image together with the flattening of its width/height-swapped version; see the small demo below.
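A tiny numeric demo (not repo code) makes the two scan orders visible:

import torch

x = torch.arange(4.).view(1, 1, 2, 2)   # (b, d, h, w); the image is [[0, 1], [2, 3]]
B_, L = 1, 4
row_major = x.view(B_, -1, L)                                   # h-then-w order: [0, 1, 2, 3]
col_major = x.transpose(2, 3).contiguous().view(B_, -1, L)      # w-then-h order: [0, 2, 1, 3]
x_hwwh = torch.stack([row_major, col_major], dim=1).view(B_, 2, -1, L)
print(x_hwwh[0, :, 0])   # tensor([[0., 1., 2., 3.], [0., 2., 1., 3.]])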
Output | Shape |
---|---|
out_y | [b, 4, d, l] |
inv_y | [b, 2, d, l] |
wh_y | [b, d, h*w] |
invwh_y | [b, d, h*w] |
y | [b, d, h*w] -> [b, l, d] |
output y | [b, h, w, d] |
def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False):
    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
        return SelectiveScanCore.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False)

    if not channel_first:
        x = x.permute(0, 3, 1, 2).contiguous()
    B, D, H, W = x.shape
    D, N = self.A_logs.shape
    K, D, R = self.dt_projs_weight.shape
    L = H * W

    x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
    xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
    # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

    xs = xs.float().view(B, -1, L)  # (b, k * d, l)
    dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
    Bs = Bs.float()  # (b, k, d_state, l)
    Cs = Cs.float()  # (b, k, d_state, l)

    As = -torch.exp(self.A_logs.float())  # (k * d, d_state)
    Ds = self.Ds.float()  # (k * d)
    dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

    out_y = selective_scan(
        xs, dts,
        As, Bs, Cs, Ds,
        delta_bias=dt_projs_bias,
        delta_softplus=True,
    ).view(B, K, -1, L)

    inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
    wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y

    y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
    y = self.out_norm(y).view(B, H, W, -1)
    return (y.to(x.dtype) if to_dtype else y)
def forward_corev2
The v2 version moves most of the work into the scan part, cross_selective_scan.
def forward_corev2(self, x: torch.Tensor, channel_first=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scan, force_fp32=None):
    if not channel_first:
        x = x.permute(0, 3, 1, 2).contiguous()
    x = cross_selective_scan(
        x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
        self.A_logs, self.Ds, delta_softplus=True,
        out_norm=getattr(self, "out_norm", None),
        out_norm_shape=getattr(self, "out_norm_shape", "v0"),
        force_fp32=force_fp32,
        SelectiveScan=SelectiveScan,
    )
    return x
Scan
The scan is the core computation of the SSM; a minimal reference sketch of what it computes is given below.
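Before looking at the cross-scan wrapper, here is a minimal, deliberately slow pure-PyTorch sketch (in the spirit of mamba_minimal, not repo code) of the recurrence that the selective scan computes for a single scan direction. The CUDA kernels wrapped by SelectiveScanCore / SelectiveScanOflex compute the same thing much more efficiently; the zero-order-hold-style discretization (A_bar = exp(ΔA), B_bar·u ≈ Δ·B·u) is the one used in Mamba.

import torch
import torch.nn.functional as F

def selective_scan_ref(u, delta, A, B, C, D, delta_bias=None, delta_softplus=True):
    """Reference scan for ONE direction.
    Shapes: u, delta: (b, d, l); A: (d, n); B, C: (b, n, l); D: (d,)."""
    if delta_bias is not None:
        delta = delta + delta_bias[None, :, None]
    if delta_softplus:
        delta = F.softplus(delta)
    b, d, l = u.shape
    n = A.shape[1]
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))   # A_bar = exp(delta * A)
    deltaBu = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)     # B_bar * u ~= delta * B * u
    h = u.new_zeros(b, d, n)
    ys = []
    for t in range(l):                   # sequential recurrence over the flattened pixels
        h = deltaA[:, :, t] * h + deltaBu[:, :, t]
        ys.append(torch.einsum("bdn,bn->bd", h, C[:, :, t]))
    y = torch.stack(ys, dim=-1)          # (b, d, l)
    return y + u * D[None, :, None]      # D acts as a skip connection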
Cross selective scan
def cross_selective_scan
This part is called by the v2 forward variant.
This is where the actual scan happens: the function below calls selective_scan, which wraps the SelectiveScanOflex class, and the scan itself is carried out by the imported selective_scan_cuda_oflex module.
The variables and dimensions are defined as in def forward_corev0; ys has shape [b, 4, d, h, w], and after the CrossMerge step the result has shape [b, d, l].
def cross_selective_scan(
    x: torch.Tensor=None,
    x_proj_weight: torch.Tensor=None,
    x_proj_bias: torch.Tensor=None,
    dt_projs_weight: torch.Tensor=None,
    dt_projs_bias: torch.Tensor=None,
    A_logs: torch.Tensor=None,
    Ds: torch.Tensor=None,
    delta_softplus=True,
    out_norm: torch.nn.Module=None,
    out_norm_shape="v0",
    # ==============================
    to_dtype=True,    # True: final out to dtype
    force_fp32=False, # True: input fp32
    # ==============================
    nrows=-1,     # for SelectiveScanNRow; 0: auto; -1: disable;
    backnrows=-1, # for SelectiveScanNRow; 0: auto; -1: disable;
    ssoflex=True, # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore
    # ==============================
    SelectiveScan=None,
    CrossScan=CrossScan,
    CrossMerge=CrossMerge,
):
    # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);...
    B, D, H, W = x.shape
    D, N = A_logs.shape
    K, D, R = dt_projs_weight.shape
    L = H * W

    if nrows == 0:
        if D % 4 == 0:
            nrows = 4
        elif D % 3 == 0:
            nrows = 3
        elif D % 2 == 0:
            nrows = 2
        else:
            nrows = 1

    if backnrows == 0:
        if D % 4 == 0:
            backnrows = 4
        elif D % 3 == 0:
            backnrows = 3
        elif D % 2 == 0:
            backnrows = 2
        else:
            backnrows = 1

    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)

    xs = CrossScan.apply(x)

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
    if x_proj_bias is not None:
        x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)

    xs = xs.view(B, -1, L)
    dts = dts.contiguous().view(B, -1, L)
    As = -torch.exp(A_logs.to(torch.float))  # (k * c, d_state)
    Bs = Bs.contiguous()
    Cs = Cs.contiguous()
    Ds = Ds.to(torch.float)  # (K * c)
    delta_bias = dt_projs_bias.view(-1).to(torch.float)

    if force_fp32:
        xs = xs.to(torch.float)
        dts = dts.to(torch.float)
        Bs = Bs.to(torch.float)
        Cs = Cs.to(torch.float)

    ys: torch.Tensor = selective_scan(
        xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
    ).view(B, K, -1, H, W)

    y: torch.Tensor = CrossMerge.apply(ys)

    if out_norm_shape in ["v1"]:  # (B, C, H, W)
        y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1)  # (B, H, W, C)
    else:  # (B, L, C)
        y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
        y = out_norm(y).view(B, H, W, -1)

    return (y.to(x.dtype) if to_dtype else y)
class CrossScan
A custom torch.autograd.Function with hand-written forward and backward.
Variable | Shape |
---|---|
input x | [b, d, h, w] |
xs | [b, 4, d, l] |
xs[:, 0] | [b, d, h*w] |
xs[:, 1] | [b, d, w*h] |
xs[:, 2:4] | [b, 2, d, l] |
The gradient coming into backward has shape (b, 4, d, l), matching xs; backward merges it back and returns a gradient of shape (b, d, h, w), matching the input x.
class CrossScan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 4, C, H * W))
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        return y.view(B, -1, H, W)
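A small check (not repo code) of the four scan orders CrossScan produces on a 2x2 image:

import torch

x = torch.arange(4.).view(1, 1, 2, 2)   # the image is [[0, 1], [2, 3]]
xs = CrossScan.apply(x)                 # (b, 4, d, l)
print(xs[0, :, 0])
# tensor([[0., 1., 2., 3.],    row-major scan (h then w)
#         [0., 2., 1., 3.],    column-major scan (w then h)
#         [3., 2., 1., 0.],    reversed row-major scan
#         [3., 1., 2., 0.]])   reversed column-major scan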
Merge
Cross merge (class CrossMerge)
Variable | Shape |
---|---|
ys | [b, 4, d, h, w] |
ys[:, 0:2] | [b, 2, d, l] |
ys[:, 2:4] | [b, 2, d, l] |
y | [b, d, l] |
class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        xs = x.new_empty((B, 4, C, L))
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs = xs.view(B, 4, C, H, W)
        return xs
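Finally, a round-trip check (a sketch, not repo code): because CrossMerge un-flips and un-transposes the four branches before summing them, merging the output of CrossScan gives back exactly four times the flattened input.

import torch

x = torch.randn(2, 3, 4, 5)                    # (b, d, h, w)
xs = CrossScan.apply(x).view(2, 4, 3, 4, 5)    # (b, 4, d, h, w)
y = CrossMerge.apply(xs)                       # (b, d, l)
print(torch.allclose(y, 4 * x.flatten(2)))     # True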