题目:VMamba: Visual State Space Model(视觉状态空间模型)
论文:[2401.10166] VMamba: Visual State Space Model (arxiv.org)
目录
3.2. 2D-Selective-Scan for Vision Data (SS2D,2D 选择扫描)
3.3. Accelerating VMamba(加速 VMamba)
一、摘要
研究背景:在计算机视觉领域,设计计算效率高的网络架构一直是一种需要。
主要工作:本文将状态空间语言模型 Mamba 移植到 VMamba 中,VMamba 是一种线性时间复杂度的视觉骨干。
- VMamba 的核心是 视觉状态空间(VSS)模块 的堆叠 和 2D选择性扫描 (SS2D) 模块。
- 通过沿4条扫描路线遍历,SS2D 有助于弥合 1D 选择性扫描的有序性质与 2D 视觉数据的非顺序结构之间的差距,有利于从不同来源和角度收集上下文信息。
- 基于VSS模块,开发了一系列VMamba架构,并通过一系列架构和实现的增强来加速它们。
实验效果:广泛的实验展示了VMamba在各种视觉感知任务中的良好性能,突出了其与现有基准模型相比在输入缩放效率方面的优势。
二、引言
ViT概述(优缺点): ViT通常在大规模数据上表现出上级的学习能力。然而,自注意的二次复杂性导致受限。在涉及大空间分辨率的下游任务中会引入了相当大的计算开销。
研究现状:为了应对这一挑战,许多方法已经做出了相当大的努力来提高注意力计算的效率。然而,现有的方法要么对有效感受野的大小施加限制,要么在不同的任务中经历明显的性能下降。
Mamba概述:最近,在自然语言处理(NLP)领域,Mamba 是一种新的状态空间模型(SSM),已经成为具有线性复杂度的长序列建模的一种非常有前途的方法。
研究问题:然而,Mamba的核心算法 —— 并行化选择扫描操作本质上是为处理一维序列数据而设计的。当试图将其用于处理视觉数据时,这提出了一个挑战,因为视觉数据天生缺乏视觉组件的顺序排列。
主要工作:为了解决这个问题,本文提出了一种 面向空间域遍历的四路扫描机制 SS2D (2D Selective Scan) 。与自注意力机制相比,SS2D 确保每个图像 patch 通过沿着相应的扫描路径计算的压缩隐藏状态独占地获得上下文知识,从而将计算复杂度从二次降低到线性。
实验效果:具体而言,在 ImageNet-1K 上,VMamba-Base 的准确率达到了83.9%,排在第一位,超过 Swin +0.4%,吞吐量大大超过 Swin 40%(646对458)。Vmamba 的优越性扩展到各种下游任务,其中 Vmamba-Tiny/Small/Base 在 COCO 上的目标检测中实现了 47.3%/48.7%/49.2% 的 mAP(1×训练计划)。这分别超过 Swin 4.6%/3.9%/2.3% 和 ConvNeXt 3.1%/3.3%/2.2% 的表现。对于 ADE20K 上的单尺度语义分割,VMamba-Tiny/Small/Base 实现了 47.9%/50.6%/51.0% mIoU,分别超过 Swin 3.4%/3.0%/2.9% 和 ConvNeXt 1.9%/1.9%/1.9% 。此外,基于 ViT 的模型的计算复杂性会随着输入 token 的数量呈二次增长,而与此不同,VMamba 在保持相当性能的同时,FLOP 呈线性增长。
主要贡献:
- 提出了 VMamba,一个基于SSM的视觉骨干网络,用于具有 线性时间复杂度 的视觉表示学习。为了提高VMamba的推理速度,在体系结构设计和实现细节上进行了一系列改进。
- 引入了 二维选择性扫描(SS2D)来弥补一维阵列扫描和二维平面遍历之间的差距,便于将选择性扫描扩展到处理视觉数据。
- VMamba 在一系列视觉任务中表现出了良好的性能,包括图像分类、对象检测和语义分割。它还表现出对输入序列长度的显著适应性,显示出计算复杂度的线性增长。
三、方法
3.1. Network Architecture
流程:输入图像 首先被 stem模块分割成 patch,形成空间维度为 H/4 × W/4 的 2D 特征图。随后,采用多个网络阶段来创建 H/8 × W/8、H/16 × W/16 和 H/32 × W/32 分辨率的分层表示。每个阶段包括一个下采样层 (除了第一阶段),然后是 视觉状态空间(VSS)模块 的堆叠。(Stem模块 + 下采样层 + VSS模块)
VSS模块:为了进一步提高计算效率,消除了整个乘法分支 (图中红色框),因为门控机制的效果是通过SS2D的选择性实现的。因此,得到的 VSS 模块 (如图(d)所示) 由带有两个残差模块的单个网络分支组成,模仿了普通Transformer模块的体系结构。
3.2. 2D-Selective-Scan for Vision Data (SS2D,2D 选择扫描)
问题:虽然 S6 中的扫描操作的顺序性质与涉及时间数据的 NLP 任务很好地一致,但当应用于视觉数据时,其提出了显著的挑战,视觉数据本质上是非顺序的并且包含空间信息(例如,局部纹理和全局结构)。
解决方法:坚持选择性扫描方法进行输入处理,并提出了 2D 选择性扫描(SS2D)模块,以使 S6 适应视觉数据,而不损害其优势。
过程:给定输入数据,SS2D 首先 沿着四个不同的遍历路径 将输入 patch 展开为序列(即,交叉扫描),使用单独的S6模块并行地处理每个 patch 序列,并随后对所得到的序列进行整形和合并以形成输出图(即,交叉合并)。通过采用互补的1D遍历路径,SS2D 使得图像中的每个像素能够有效地整合来自不同方向上的所有其他像素的信息,从而有助于在2D空间中建立全局感受野。
如下图所示,SS2D 中的 data forwarding (前向)包括三个步骤:cross-scan(交叉扫描)、selective scanning(基于S6模块的选择性扫描) 和 cross-merge(交叉合并)。
3.3. Accelerating VMamba(加速 VMamba)
通过ImageNet-1 K上的图像分类来评估模型。每个渐进式改进的影响总结如下,其中(%,img/s)分别表示ImageNet-1 K上的 top-1准确度 和 推理吞吐量 的增益。如下所示:
- 步骤(a) (+0.0%,+41 img/s) 通过在 Triton 中重新实现 交叉扫描和交叉合并 。
- 步骤(b) (+0.0%,−3 img/s) 通过 调整 CUDA 实现的选择性扫描 ,以适应 float16 输入和 float32 输出。这显著提高了训练效率 (吞吐量从165到184),尽管在测试时速度略有波动。
- 步骤(c) (+0.0%,+174 img/s) 将选择性扫描中相对缓慢的 einsum 替换为线性变换 (即torch.nn.function.linear)。采用 (B, C, H, W) 张量布局来消除不必要的数据排列。
- 步骤(d) (−0.6%,+175 img/s) 由于其计算效率,将 MLP 引入 VMamba。丢弃了DWConv层,并将层配置从 [2,2,9,2] 更改为 [2,2,2,2],以降低FLOPs。
- 步骤(e) (+0.6%, +366 img/s) 丢弃整个乘法支路,并将参数 ssm-ratio(特征扩展因子) 从2.0降低到1.0。这允许层数提高到 [2,2,5,2],同时减少FLOPs。(证明所提出的VMamba的体系架构的有效性)
- 步骤(f) (+0.3%, +161 img/s) 通过将 参数 d_state (SSM状态维度) 从16.0减少到1.0。这允许ssm-ratio提高到2.0,并引入DWConv层,而不会增加FLOPs。
- 步骤(g) (+0.1%, +346 img/s) 将ssm-ratio降低到1.0,同时将层配置从[2,2,5,2]更改为[2,2,8,2]。
四、实验
五、代码实现
VMamba-main/classification/models/vmamba.py
- 1) 初始化 delta 投影参数
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
r""" 初始化 delta 投影参数
Args:
dt_rank:输入特征的维度
d_inner:输出特征的维度
dt_scale:初始化标准差的缩放因子,默认值为 1.0
dt_init:初始化方法,可以是 "constant" 或 "random"
dt_min 和 dt_max:初始化偏置时的最小和最大值
dt_init_floor:偏置的下限,防止过小的值
"""
# 创建线性层
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
# 根据选择的初始化方法,使用常量或随机均匀分布来设置线性层权重
dt_init_std = dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# 初始化偏置
dt = torch.exp(
torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# 计算偏置的逆 Softplus; Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
# 更新偏置值
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
# dt_proj.bias._no_reinit = True
return dt_proj
- 2) 初始化一个参数张量 A
A: (D,N)← Parameter
def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
r""" 初始化一个参数张量 A_log (矩阵A)
Arg:
d_state:状态维度,表示生成的张量的长度
d_inner:内层维度,表示输出张量的行数
copies:用于重复张量的数量,默认为 -1 表示不重复
device:指定张量存储的设备(如 CPU 或 GPU)
merge:布尔值,决定是否将多个副本合并为一个张量
"""
# 创建张量 A, (d_inner, d_state)
A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
# 离散化, 计算 A 的对数
A_log = torch.log(A) # Keep A_log in fp32
# 是否对 A_log 进行扩展; (d_inner, d_state) -> (copies, d_inner, d_state) -> (copies * d_inner, d_state)
if copies > 0:
A_log = A_log[None].repeat(copies, 1, 1).contiguous()
if merge:
A_log = A_log.flatten(0, 1)
# 将 A_log 转换为可训练参数
A_log = nn.Parameter(A_log)
# 将 A_log 设置为无权重衰减
A_log._no_weight_decay = True
return A_log
- 3) 初始化一个参数张量 D
def D_init(d_inner, copies=-1, device=None, merge=True):
r""" 初始化一个参数张量 D (矩阵D)
Arg:
d_inner:内层维度,表示生成的张量的大小
copies:用于重复张量的数量,默认为 -1,表示不重复
device:指定张量存储的设备(如 CPU 或 GPU)
merge:布尔值,决定是否将多个副本合并为一个张量
"""
D = torch.ones(d_inner, device=device)
if copies > 0:
D = D[None].repeat(copies, 1).contiguous()
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D) # Keep in fp32
D._no_weight_decay = True
return D
- 4) 初始化 delta投影, A, D 矩阵 (总)
def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
r""" 初始化delta投影, A, D矩阵
Arg:
d_state:状态维度,可能用于描述模型的状态空间
dt_rank:降维或投影的秩
d_inner:内层维度,通常用于网络的内部表示
dt_scale、dt_init、dt_min、dt_max、dt_init_floor:这些参数用于控制投影的初始化
k_group:组的数量,默认为 4,用于重复初始化
"""
# dt proj ============================ delta 投影
dt_projs = [
cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
for _ in range(k_group)
]
# delta 投影权重和偏置
dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) # (K, inner, rank)
dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) # (K, inner)
del dt_projs
# A, D =======================================
A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N)
Ds = cls.D_init(d_inner, copies=k_group, merge=True) # (K * D)
return A_logs, Ds, dt_projs_weight, dt_projs_bias
- 5) VMamba_v0 版本
# support: v0, v0seq
class SS2Dv0:
def __initv0__(
self,
# basic dims ===========
d_model=96,
d_state=16,
ssm_ratio=2.0,
dt_rank="auto",
# ======================
dropout=0.0,
# ======================
seq=False,
force_fp32=True,
**kwargs,
):
r""" V-Mamba-v0 框架
Arg:
d_model: 模型的输出维度(默认为96)。
d_state: 状态维度(默认为16)。
ssm_ratio: 状态维度与模型维度的比率(默认为2.0)。
dt_rank: 动态时间参数的维度,默认为“auto”,会根据 d_model 计算
"""
if "channel_first" in kwargs:
assert not kwargs["channel_first"]
act_layer = nn.SiLU
dt_min = 0.001
dt_max = 0.1
dt_init = "random"
dt_scale = 1.0
dt_init_floor = 1e-4
bias = False
conv_bias = True
d_conv = 3
k_group = 4
factory_kwargs = {"device": None, "dtype": None}
super().__init__()
d_inner = int(ssm_ratio * d_model)
dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
self.forward = self.forwardv0
if seq:
self.forward = partial(self.forwardv0, seq=True)
if not force_fp32:
self.forward = partial(self.forwardv0, force_fp32=False)
# in proj ============================
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
self.act: nn.Module = act_layer()
self.conv2d = nn.Conv2d(
in_channels=d_inner,
out_channels=d_inner,
groups=d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
# x proj ============================
self.x_proj = [
nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
for _ in range(k_group)
]
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
del self.x_proj
# dt proj, A, D ============================
self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
)
# out proj =======================================
self.out_norm = nn.LayerNorm(d_inner)
self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs):
x = self.in_proj(x)
x, z = x.chunk(2, dim=-1) # (b, h, w, d)
z = self.act(z)
x = x.permute(0, 3, 1, 2).contiguous()
x = self.conv2d(x) # (b, d, h, w)
x = self.act(x)
selective_scan = partial(selective_scan_fn, backend="mamba") # 选择性扫描(加速)
B, D, H, W = x.shape
D, N = self.A_logs.shape
K, D, R = self.dt_projs_weight.shape
L = H * W
""" 四个不同的遍历路径 """
# 堆叠输入张量 x 的两个视角(原始和转置), [b, 2, d, l]
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
# 拼接 x_hwwh 和 其翻转
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
# 将 xs 通过权重矩阵 self.x_proj_weight 进行投影
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
if hasattr(self, "x_proj_bias"):
x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
""" x --投影-> delta, B, C 矩阵 """
# 由投影后的x分别得到 delta, B, C 矩阵, '(B, L, D) -> (B, L, N)'
dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
# 将 dts(delta) 通过权重矩阵 self.dt_projs_weight 进行投影, '(B, L, N) -> (B, L, D)'
dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
xs = xs.view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().view(B, -1, L) # (b, k * d, l) # 保证 delta, B, C 在内存中的连续(加速计算)
Bs = Bs.contiguous() # (b, k, d_state, l)
Cs = Cs.contiguous() # (b, k, d_state, l)
As = -self.A_logs.float().exp() # (k * d, d_state)
Ds = self.Ds.float() # (k * d)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
# assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
# assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1
to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
if force_fp32:
xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
if seq:
out_y = []
for i in range(4):
""" 选择性扫描 """
yi = selective_scan(
xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
delta_bias=dt_projs_bias.view(K, -1)[i],
delta_softplus=True,
).view(B, -1, L) # 在 selective_scan 函数中进行离散化操作
out_y.append(yi)
out_y = torch.stack(out_y, dim=1)
else:
out_y = selective_scan(
xs, dts,
As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L) # 在 selective_scan 函数中进行离散化操作
assert out_y.dtype == torch.float
""" 四种遍历路径叠加 (Mamba之后) """
# token位置还原
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
# 四种状态叠加
y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
# 还原形状,方便输出
y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
y = self.out_norm(y).view(B, H, W, -1)
# z是一个门控(SiLU激活分支)
y = y * z
out = self.dropout(self.out_proj(y))
return out
推荐阅读
2、 一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba_mamba模型-CSDN博客