对于大小为1024×1024像素的图像,它可以实现每张图像7毫秒(ms)的推理速度,比普通SAM快30.1倍,比最先进的方法快2.1倍。此外,它仅需要244 MB内存,相当于普通SAM的(3.5%)。推理一张图仅需7毫秒

论文链接:https://arxiv.org/pdf/2403.09195.pdf

代码 & 权重地址:https://anonymous.4open.science/r/SAM-LIGHTENING-BC25/

Segment Anything ModelSAM)由于其zero-shot泛化能力,在分割任务中引起了相当大的关注。然而,由于其低推理速度和高计算内存需求,SAM在实际应用中的广泛应用受到了限制,这主要源自注意力机制。现有工作集中在优化编码器上,然而尚未充分解决注意力机制本身的低效率问题,即使是在将其蒸馏到较小的模型时,这也为进一步改进留下了空间。

为此,这里引入了SAM-Lightening,SAM的一种变体,它具有重新设计的注意力机制,称为Dilated Flash Attention。它不仅促进了更高的并行性,增强了处理效率,而且还保留了与现有的Flash Attention的兼容性。相应地,提出了一种渐进式蒸馏方法,使得可以在不需要昂贵的从头训练的情况下,有效地从普通的SAM中传递知识。

在COCO和LVIS上的实验证明,SAM-Lightening在运行时效率和分割准确性方面明显优于最先进的方法。具体而言,对于大小为1024×1024像素的图像,它可以实现每张图像7毫秒(ms)的推理速度,比普通SAM快30.1倍,比最先进的方法快2.1倍。此外,它仅需要244 MB内存,相当于普通SAM的(3.5%)。

介绍

传统上,图像分割受到深度学习模型必须针对特定任务设计的数据集进行专门训练的限制。手工制作数据集的专门化通常限制了它们的生成能力。为了解决这一限制,Segment Anything Model(SAM)以其zero-shot学习能力代表了一种范式转变,使其能够对新的和未见过的图像进行分割。然而,SAM在增强现实(AR)、图像编辑、在智能手机上部署和医学成像等不同领域的应用受到其图像编码器中的计算负担挑战的影响,该编码器包含了巨大的6.32亿参数。这个尺寸大约是传统分割网络如U-Net的20倍,导致了高计算需求。

针对这一挑战,已经尝试了各种努力。例如,FastSAM采用了一种策略,即用更简化的卷积神经网络(CNN)替换SAM的transformer 编码器,旨在创建一个更轻量级的模型。然而,这往往会导致准确性降低,特别是在复杂的分割任务中。另一个显著的方法是MobileSAM,它采用蒸馏技术将知识从SAM的编码器传输到更紧凑的ViT-tiny编码器。类似地,像EfficientSAM这样的倡议旨在改进MobileSAM的训练过程以提高准确性。相反,SAMFast通过诸如量化和剪枝等技术,专注于对原始SAM的速度优化,但这些修改对性能提升的影响有限。

我们的研究确定了先前关于SAM的研究中的关键局限性,主要表现在注意力机制中计算效率低和内存使用不足方面。为了解决这些问题,将FlashAttention和Dilated Attention机制集成到我们的SAM框架中,提供了对现有方法的正交改进。这些增强不仅减少了内存消耗,还改善了并行处理,使它们与先前的进展相辅相成。然而,直接将这些机制应用于SAM将需要对模型进行完全重新训练,带来巨大的计算成本。

为了避免这一挑战,我们提出了一种动态分层蒸馏(DLD)方法。DLD通过逐渐分配特征权重,为图像编码器实现了一种渐进式蒸馏方案,有效促进了从SAM到轻量级模型的知识转移。实验证明,我们的模型(SAM-Lightening)不仅具有足够的表达能力来表示原始SAM,而且在计算上效率高,完成推理过程仅需7毫秒。

简而言之,本文的主要贡献有四个方面:

  • 引入了一种新颖的SAM结构,SAM-Lightening,以显著降低计算复杂性。
  • 设计了一种新颖的Dilated Flash Attention,用于取代普通的自注意力机制,以提高SAM-Lightening的效率和推理速度。
  • 为了有效地将知识从普通SAM传输到SAM-Lightening,提出了一种动态分层蒸馏方法,而不会影响性能。
  • SAM-Lightening实现了每张图像7毫秒的最先进性能,比普通SAM快30.1倍。

相关工作

Segment Anything Model:    SAM 由三个主要部分组成:图像编码器、提示编码器和mask解码器。值得注意的是,图像编码器是SAM中参数最多的部分,占据了其处理时间的98.3%,这突显了优化的必要性。FastSAM采用了CNN编码器,具体是YOLOv8-seg,来替换ViT编码器以提高处理速度。然而,已经观察到这种做法会损害分割的精度,特别是在复杂场景和捕捉细微边缘细节方面。MobileSAM对编码器进行了蒸馏,以减小模型大小和计算需求。然而,MobileSAM编码器结构和参数分布的不平衡限制了其在实际部署和性能优化方面的潜力。SAMFast 代表了另一种优化策略,专注于使用量化和稀疏化等方法提高SAM的处理速度。虽然这种方案确实提供了一些加速,但其整体影响仍然适中。另一方面,EfficientSAM改进了MobileSAM的训练方法,具体针对MobileSAM方法的准确性方面进行了优化。

FlashAttention: FlashAttention机制引入了一种在神经网络中计算注意力的高效准确的方法。它通过策略性的切片和重计算技术,主要实现了高带宽内存读写的显著减少。在此基础上,FlashAttention-2通过增强的矩阵乘法操作进一步改进了这一过程。这些改进已经被证明在特定的计算设置中可以实现性能增加高达两倍。

知识蒸馏:  知识蒸馏是一种将复杂模型的知识转移到简单模型的技术。它们旨在保留较大模型的性能属性,同时显著减少其计算占用和模型大小。MobileSAM采用了一种解耦的知识蒸馏,通过从原始SAM的ViT-H图像编码器中提取输出,并将其直接蒸馏到预训练的ViT-tiny编码器中。这种策略对于已经具有预训练参数的较小模型特别有益。

方法

Dilated Flash Attention

为了解决SAM图像编码器中高计算需求的问题,我们设计了一种新颖的注意力操作,利用FlashAttention来加快推理速度。

分割和稀疏化: 为了减轻在注意力操作中处理 $(Q, K, V)$ 中的计算负担, 将每个输入分成相等长度的部分 $(w)$ ,然后在每个部分的序列维度上应用稀疏化。这种稀疏化包括在固定间隔 $(r)$ 处选择行, 从而减少了注意力机制需要处理的数据量。如图1所示, 稀疏化过程可以表示为:

SAM-Lightening_数据集

FlashAttention的并行处理: 每个输入数据的稀疏化段是可以独立参与注意力计算的稠密矩阵,因此可以并行处理。这种并行性对于高效管理大规模图像数据集至关重要,可以显著加快处理速度,提高模型在实时图像分割中的效率。通过将FlashAttention纳入其中,进一步增加了通过并行化稠密矩阵计算来提高效率的可能性。

输出重组:  在提出的Dilated Flash Attention框架中, 我们并行处理稀疏化的段, 实现对 $\tilde{Q_i}$和 $\tilde{K}_i$ 转置的乘积应用 softmax函数, 然后将其与 $\tilde{V}_i$ 相乘, 如下所示:

SAM-Lightening_数据集_02

将这些输出重新组合成连贯的最终输出O,涉及一个经过精心设计的过程:

  • (1) 首先, 建立一个与原始输入维度相同的零矩阵 $O_{i n i t}$, 用于累积各个段的输出。
  • (2) 对于每个计算得到的段输出 $\tilde{Q_i}$, 确定一个特定的偏移量 $\gamma_i$ 。这个偏移量确定了 $\tilde{O}_i$ 在 $O_{i n i t}$ 矩阵中的精确起始位置。
  • (3) 使用基于其 $\gamma_i$ 的映射操作将每个 $\tilde{Q_i}$ 映射到 $O_{i n i t}$ :

SAM-Lightening_编码器_03

"MAP" 操作根据 $\gamma_i$ 确定的位置将每个 $\tilde{Q_i}$ 元素放置到 $O_{i n i t}$ 中。这确保了根据其原始输入位置, 每个段的输出在最终输出矩阵 $O$ 中的精确对齐。

计算效率: 通过提出的Dilated Flash Attention机制, 效率在数量上提高了一个因子 $\frac{N}{w r^2}$, 其中 $N$ 表示输入的总大小, $w$ 表示每个分割的长度, $r$ 表示稀疏化的间隔。这种数学关系表明, 对于任何给定的输入大小, Dilated Flash Attention需要的计算量大大减少。因此, 这提高了模型在高效处理大规模图像分割任务方面的能力, 显著提高了性能和实用性。

动态分层蒸馏 (DLD)

从头开始训练SAM-Lightening成本高昂,而由于SAM与以ViT-H为特征编码器的SAM-Lightening之间的结构差异,层适应性具有挑战性。为了实现从普通SAM到提出的框架的有效知识转移,我们提出了一种新颖的动态分层蒸馏(DLD)方法,该方法动态修改特征权重,以增强模型之间的逐层蒸馏。 

动态分层权重: 当前面的层次没有很好地被蒸馏时,后续层次的性能可能会受到从前面层次提取的低质量特征的影响。通过给予这些初始层损失更大的权重,动态加权确保它们在训练过程中更受关注。这有助于在初始阶段更好地将学生模型与教师模型对齐。给定一个由 $L$ 层组成的深度神经网络, 每一层 $i$ 与一个时间权重 $\alpha_i(t)$ 相关联。这种机制调整了神经网络中每一层 $i$ 在各个训练阶段 $t$ 中的重要性。初始层保持最大的焦点 $\alpha_1(t)=1$, 后续层遵循动态加权方案, 可以用分段函数在数学上表示为:

SAM-Lightening_人工智能_04

其中, $T_i$ 表示第 $i^{t h}$ 层开始更新权重的时期, 前一层已经达到饱和, 即 $T_i=T_{i-1}+\Delta t$ 。参数 $\Delta t$ 捕获了权重从 0 过渡到 1 所需的epoch数。对于预定义的epoch增量 $\Delta t$, 每一层在前一层达到峰值权重后依次激活其学习潜力。这种机制有助于从教师模型中实现级联知识吸收。

解耦的特征蒸馏: 蒸馏过程将知识从SAM的编码器(教师模型)转移到我们提出的编码器(学生模型),如图1所示。选择了距离输出最近的层进行特征蒸馏。由于这些更深层次直接与模型的输出相关,蒸馏它们可以更有效地传递关键信息以获得预测结果。这些层被指定为“焦点层”。

在训练的初始阶段,靠近输入的层次被赋予优先权。这里的意图是将SAM-Lightning学生模型的主要特征表示,表示为 $f_{S A M-L}^i(x)$, 与教师模型的特征表示 $f_{S A M}^i(x)$ 对齐,与教师模型的特征表示对齐,对于最接近输入的层。随着训练的进行,逐层加权动态地转移。与后续层相关的损失被逐渐放大。在这个过程中,损失函数演变为吸收来自后续层的表示:

SAM-Lightening_人工智能_05

其中, $L$ 是完整层次的计数, 而系数 $\alpha(i)$ 是由训练epoch和层次 $i$ 确定的分段函数。集成的蒸馏损失被表述为:

SAM-Lightening_权重_06

其中, $L_P$ 封装了所有选择的特征层损失的加权和, $L_{\text {output }}$ 是图像编码器输出层的损失, $\lambda$是一个缩放因子, 用于平衡解码器输出在整体蒸馏过程中的重要性。

对齐解码器: 此外,通过解耦的蒸馏获得的轻量级图像编码器与冻结的解码器存在对齐问题,特别是对于基于点的提示分割任务。因此,我们通过在SA-1B数据集上对点提示和框提示进行采样,对解码器进行微调,以与图像编码器对齐。损失函数定义如下:

SAM-Lightening_权重_07

这里,IOU代表交并比损失,而Dice损失和Focal损失分别用于处理类别不平衡和具有挑战性的分割区域。

实验

实验设置

我们的模型利用SA-1B数据集的1%进行蒸馏和微调。它具有一个embedding维度为384的编码器,六个注意力头和一个六层结构。对于FlashAttention组件,使用bfloat16。蒸馏和微调过程分别进行了10个时期,学习率为$10^{-3}$,批量大小为32。梯度累积设置为4步。该模型在两个NVIDIA RTX 4090 GPU上进行训练。为了提高训练速度,SAM的图像编码器的输出被保存了下来。

结果

运行时间和内存效率评估: 将我们提出的SAM-Lightening与原始SAM(即SAM-ViT-H)、FastSAM、MobileSAM、EfficientSAM和SAMFast在表1和表2中进行了性能比较。就分割性能而言,原始SAM被视为上限。重要的是,表1显示SAM-Lightening在推理延迟和峰值内存使用方面优于所有对手,与原始SAM相比,实现了30.1倍的加速和96.5%的峰值内存减少,并且与现有技术相比实现了2.1倍的加速。表2中的吞吐量比较进一步强化了SAM-Lightening的优越性能,它在各种batch大小下均实现了最高的吞吐量。总的来说,这种高吞吐量与低延迟和内存使用率相结合,使SAM-Lightening成为图像分割任务的高效模型。

SAM-Lightening_权重_08

SAM-Lightening_编码器_09

在框提示/点提示模式下的比较: 首先在边界框和基于点的提示下评估性能。对于边界框提示,按照原始SAM中的设置,利用COCO和LVIS中的真值标注,合成定义每个图像中感兴趣区域的边界框。对于点提示,在真值mask中随机采样点,挑战所有模型准确地分割与每个点相关的目标或区域。在定量上,使用平均交并比(mIoU)作为指标。如表3所示,与原始SAM相比,SAMFast和MobileSAM在边界框提示和点提示方面都表现出性能下降,特别是在点提示方面。作为基于CNN的模型,FastSAM显示出更为明显的下降,特别是在处理包含大量小目标的LVIS数据集时,这一点尤为明显。这种观察反映了CNN编码器在处理更复杂的分割场景时的局限性。相反,SAM-Lightening在分割性能方面与原始SAM最为匹配。即使在基于点的提示场景下,SAM-Lightening的mIoU也与原始SAM相似。

SAM-Lightening_编码器_10

SAM-Lightening_权重_11

在任意模式下的比较: 虽然"segment-anything"模式是一种创新的方法,但并不是一种常用的分割方法,因此并不能有效地代表典型的分割任务。因此,我们的分析主要集中在通过基于点和基于框的方法在视觉上比较分割结果,这在实际应用中更为普遍。然而,为了完整性和展示模型的多功能性,还在比较中包含了"segment-anything"模式的输出。

从图3中展示的代表样本, SAM-Lightening和MobileSAM在分割结果上几乎与原始SAM无法区分。这种相似性在边缘清晰度和细节保留方面非常明显,这是高质量分割的特征。SAM-Lightening展示了其稳健性和准确性,与原始SAM的性能密切对齐。

SAM-Lightening_编码器_12

消融研究

值得注意的是,许多先前的工作在SAM的输入尺寸上使用小于1024的尺寸。为了公平比较,我们也在这些场景下进行了实验,并发现保持FlashAttention的输入尺寸等于或小于512×512可以实现最佳性能。这表明FlashAttention的适用性取决于模型的输入尺寸和特定的硬件配置。决定使用FlashAttention应基于特定的应用背景和性能要求。

虽然FlashAttention可以加速模型蒸馏的训练过程,但其对推理性能的影响取决于各种硬件指标。在我们的推理平台上,特别是对于输入尺寸为1024的SAM,多头注意力运算符表现出更多的计算密集型特征。如下图4所示,这导致使用FlashAttention的推理速度略低于不使用FlashAttention的速度。因此,我们选择在蒸馏过程中使用FlashAttention来优化性能,而在评估阶段移除它。

SAM-Lightening_权重_13

结论

SAM-Lightening用来解决原始SAM中高计算需求和推理速度慢的主要局限性,使其更适合部署在资源受限设备上。方法涉及对SAM中图像编码器的重新设计,通过将自注意力操作符蒸馏成具有动态层次蒸馏的Dilated Flash Attention。这些优化措施显著降低了计算复杂度和内存使用量,同时没有损害分割性能。

具体来说,SAM-Lightening在图像上完成推理平均每张仅需7ms,实现了比SAM-ViT-H快30.1倍的速度提升。由于SAM-Lightening与剪枝和量化是互补的,未来的一个方向可以考虑将它们整合在一起。