SAM轻量化改进SAM-Lightening论文解读SAM-LIGHTENING: A LIGHTWEIGHT SEGMENT ANYTHING MODEL WITHIN DILATED FLASH

本文介绍了一种新型的模型SAM-Lightening,通过优化注意力机制和设计DilatedFlashAttention,显著提高图像分割的运行时间和内存效率。对比实验显示,SAM-Lightening在COCO和LVIS上表现优于现有方法,尤其适合资源受限设备。
摘要由CSDN通过智能技术生成

现已总结SAM多方面相关的论文解读,具体请参考该专栏的置顶目录篇

一、总结

1. 简介

发布时间:2024年3月14日

论文:

https://arxiv.org/pdf/2403.09195.pdf

代码:

Anonymized Repository - Anonymous GitHub (4open.science)

2. 摘要

        由于SAM低推理速度和高计算内存需求,SAM在实际应用中的广泛应用受到了限制,这主要源自注意力机制。现有工作集中在优化编码器上,然而尚未充分解决注意力机制本身的低效率问题,即使是在将其蒸馏到较小的模型时,这也为进一步改进留下了空间。
        为此,SAM-Lightening对SAM中图像编码器的重新设计,通过将自注意力操作符蒸馏成具有动态层次蒸馏的Dilated Flash Attention。它不仅促进了更高的并行性,增强了处理效率,而且还保留了与现有的Flash Attention的兼容性。相应地,提出了一种渐进式蒸馏方法,使得可以在不需要昂贵的从头训练的情况下,有效地从普通的SAM中传递知识。
        在COCO和LVIS上的实验证明,SAM-Lightening在运行时效率和分割准确性方面明显优于最先进的方法,更适合部署在资源受限设备上。对于大小为1024×1024像素的图像,它可以实现每张图像7ms的推理速度,比SAM-ViT-H快30.1倍,比最先进的方法快2.1倍。此外,它仅需要244 MB内存,相当于普通SAM的(3.5%)。由于SAM-Lightening与剪枝和量化是互补的,未来的一个方向可以考虑将它们整合在一起。

3. 前言

        SAM编码器包含了巨大的6.32亿参数。这个尺寸大约是传统分割网络如U-Net的20倍,导致了高计算需求。针对这一挑战,FastSAM采用了一种策略,即用更简化的卷积神经网络(CNN)替换SAM的transformer 编码器,旨在创建一个更轻量级的模型。然而,这往往会导致准确性降低,特别是在复杂的分割任务中。MobileSAM采用蒸馏技术将知识从SAM的编码器传输到更紧凑的ViT-tiny编码器。EfficientSAM旨在改进MobileSAM的训练过程以提高准确性。相反,SAMFast通过诸如量化和剪枝等技术,专注于对原始SAM的速度优化,但这些修改对性能提升的影响有限。
        SAM-Lightening确定了先前关于SAM的研究中的关键局限性,主要表现在注意力机制中计算效率低内存使用不足方面。将FlashAttentionDilated Attention机制集成到我们的SAM框架中,提供了对现有方法的正交改进。这些增强不仅减少了内存消耗,还改善了并行处理,使它们与先前的进展相辅相成。
        然而,直接将这些机制应用于SAM需要对模型进行重新训练,带来巨大的计算成本。因此提出了一种动态分层蒸馏(DLD)方法。DLD通过逐渐分配特征权重,为图像编码器实现了一种渐进式蒸馏方案,有效促进了从SAM到轻量级模型的知识转移。实验证明,SAM-Lightening不仅具有足够的表达能力来表示原始SAM,而且在计算上效率高,完成推理过程仅需7毫秒。

4. 贡献

(1)引入了一种新颖的SAM结构,SAM-Lightening,以显著降低计算复杂性。
(2)设计了一种新颖的Dilated Flash Attention,用于取代普通的自注意力机制,以提高SAM-Lightening的效率和推理速度。
(3)为了有效地将知识从普通SAM传输到SAM-Lightening,提出了一种动态分层蒸馏方法,而不会影响性能。
(4)SAM-Lightening实现了每张图像7毫秒的最先进性能,比普通SAM快30.1倍。

二、模型结构

        SAM由三部分组成,分别为图像编码器(Image Encoder),提示编码器(Promt Encoder),掩码解码器(Mask Decoder),从头开始训练SAM-Lightening成本高昂,因此在输入一组图像后,进入由原始SAM的图像编码器进行动态分层蒸馏(DLD)成的轻量级图像编码器提取出图像嵌入,对原始SAM图像编码器及提示编码器(输出点提示和框提示的提示嵌入)进行冻结(即权重在训练过程中不更新),而轻量级图像编码器和掩码解码器是可学习的,将图像嵌入与提示嵌入输入到掩码解码器中生成分割掩码。

1. 动态层级蒸馏(DLD)

        在这个过程中,老师模型是一个预先训练好的、性能较强的模型——SAM的编码器。老师模型结构中包含多个模块,如多头自注意力(MHA)和多层感知器(MLP),它处理输入的批量样本并生成输出。学生模型是一个更轻量级的模型,通过模仿老师模型的输出来提高性能。学生模型的结构与老师模型相似,但在大小和容量上有所减小。它也包含轻量化的自注意力机制Lightening MHA和MLP模块。比较老师和学生模型的中间表示,并计算它们之间的损失函数。为了实现从SAM到SAM-Lightening的有效知识转移,提出了一种新颖的动态分层蒸馏(DLD)方法,该方法动态修改特征权重,以增强模型之间的逐层蒸馏。具体过程如下:
(1)动态分层权重:当前面的层次没有很好地被蒸馏时,后续层次的性能可能会受到从前面层次提取的低质量特征的影响。通过给予这些初始层损失更大的权重,动态加权确保它们在训练过程中更受关注。这有助于在初始阶段更好地将学生模型与教师模型对齐。给定一个由L层组成的深度神经网络, 每一层与一个时间权重相关联。这种机制调整了神经网络中每一层在各个训练阶段中的重要性。初始层保持最大的焦点, 后续层遵循动态加权方案。这种机制有助于从教师模型中实现级联知识吸收。
(2)解耦的特征蒸馏: 蒸馏过程将知识从教师模型转移到学生模型,选择了距离输出最近的层进行特征蒸馏。由于这些更深层次直接与模型的输出相关,蒸馏它们可以更有效地传递关键信息以获得预测结果。这些层被指定为“焦点层”。在训练的初始阶段,靠近输入的层次被赋予优先权。这里的意图是将SAM-Lightning学生模型的主要特征表示,表示为f_{S A M-L}^i(x_{j})与教师模型的特征表示 f_{S A M}^i(x_{j})对齐,与教师模型的特征表示对齐,对于最接近输入的层。随着训练的进行,逐层加权动态地转移。与后续层相关的损失被逐渐放大。在这个过程中,损失函数演变为吸收来自后续层的表示:

(3)对齐解码器: 此外,通过解耦的蒸馏获得的轻量级图像编码器与冻结的解码器存在对齐问题,特别是对于基于点的提示分割任务。因此,我们通过在SA-1B数据集上对点提示和框提示进行采样,对解码器进行微调,以与图像编码器对齐。

2. 注意力机制(Dilated Flash Attention)

        为了解决SAM图像编码器中高计算需求的问题,利用注意力机制FlashAttention来加快推理速度。学生模型底部展示了一个自注意力模块,其中Q、K、V分别代表查询(Query)、键(Key)和值(Value),这是Transformer架构的标准组件。自注意力模块经过轻量化处理,采用Dilated Flash Attention来降低计算复杂度。具体过程如下:
(1)分割和稀疏化:为了减轻在注意力操作中处理(Q, K, V)中的计算负担, 将输入在序列维度上分割成等长的段并应用稀疏化,选取固定间隔处的行,减少了注意力计算所需处理的数据量。
(2)FlashAttention并行处理:对输入数据进行稀疏化处理后,每个段可以作为独立的稠密矩阵参与注意力计算,实现并行处理。这种并行化对于有效处理大规模数据集至关重要,能够显著提升处理速度,提高实时图像分割任务的效率。
(3)输出重组:在Dilated Flash Attention框架中,并行处理稀疏化的段,将\widetilde{​{Q_i}}\widetilde{​{K_i}}转置后乘积的结果经softmax函数处理后与\widetilde{​{V_i}}相乘,然后再重新组合成一个连贯的最终输出\widetilde{​{O_i}}

        该过程主要包括以下几步:
        1)首先, 建立一个与原始输入维度相同的零矩阵, 用于累积各个段的输出。
        2)对于每个计算得到的段输出\widetilde{​{Q_i}},确定一个特定的偏移量\gamma_i。这个偏移量\gamma_i确定了最终输出\widetilde{​{O_i}}在零矩阵中的精确起始位置。
        3)使用基于其\gamma_i的映射操作将每个段输出\widetilde{​{Q_i}}映射到零矩阵。

        "MAP" 操作根据偏移量\gamma_i确定的位置将每个段输出\widetilde{​{Q_i}}元素放置到零矩阵中。这确保了根据其原始输入位置, 每个段的输出在最终输出矩阵中的精确对齐。
(4)计算效率:通过提出的Dilated Flash Attention机制, 效率在数量上提高了一个因子\frac{N}{w r^2},其中 N表示输入的总大小,w表示每个分割的长度,r表示稀疏化的间隔。这种数学关系表明, 对于任何给定的输入大小, Dilated Flash Attention需要的计算量大大减少。因此, 这提高了模型在高效处理大规模图像分割任务方面的能力, 显著提高了性能和实用性。

3. 多头注意力机制

        在学生模型的右侧,展示了“学生”模型中多头自注意力(multi-head attention)机制的注意力集中点。在这个情境中,每个头(head)都能独立关注输入数据的不同部分,这使得模型能够从多个角度捕捉信息。不同的颜色可能表示注意力权重的不同,注意力图用颜色的深浅来表示注意力权重的大小,深色代表更高的注意力权重,即模型对这部分输入数据的关注度更高。
        第1个头部:专注于图像的某个特定区域(如红色所示),这是一个关键特征,比如某个物体或图像的某部分结构。
        第2个头部:聚焦在不同的区域(橙色所示),捕捉与第一个头部不同的信息,为模型提供多样化的输入理解。
        第3个头部:关注点与前两个头部都不同(黄色所示),提供了更加广泛的数据理解,有助于模型综合多方面的信息以做出更准确的预测或决策。
        这种多头注意力机制的设计允许模型在处理复杂任务时,能够对不同的特征区域进行并行关注和学习,从而提高整体的识别性能。在训练过程中,这些注意力权重是学习得来的,它们决定了模型在生成输出时哪些输入信息更加重要。

三、实验

1. 实验设置

        利用SA-1B数据集的1%进行蒸馏和微调。它具有一个embedding维度为384的编码器,六个注意力头和一个六层结构。对于FlashAttention组件,使用bfloat16。蒸馏和微调过程分别进行了10个时期,学习率为10^{-3},批量大小为32。梯度累积设置为4步。该模型在两个NVIDIA RTX 4090 GPU上进行训练。为了提高训练速度,SAM的图像编码器的输出被保存了下来。

2. 结果

2.1 运行时间和内存效率评估

         将我们提出的SAM-Lightening与原始SAM(即SAM-ViT-H)、FastSAM、MobileSAM、EfficientSAM和SAMFast在下面两个表中进行了性能比较。
        就分割性能而言,原始SAM被视为上限。重要的是,下表显示SAM-Lightening在推理延迟和峰值内存使用方面优于所有对手,与原始SAM相比,实现了30.1倍的加速和96.5%的峰值内存减少,并且与现有技术相比实现了2.1倍的加速。

        下表中的吞吐量比较进一步强化了SAM-Lightening的优越性能,它在各种batch大小下均实现了最高的吞吐量。总的来说,这种高吞吐量与低延迟和内存使用率相结合,使SAM-Lightening成为图像分割任务的高效模型。

2.2 框提示/点提示模式下的比较

        对于边界框提示,按照原始SAM中的设置,利用COCO和LVIS中的真值标注,合成定义每个图像中感兴趣区域的边界框。对于点提示,在真值mask中随机采样点,挑战所有模型准确地分割与每个点相关的目标或区域。在定量上,使用平均交并比(mIoU)作为指标。如下表所示,与原始SAM相比,SAMFast和MobileSAM在边界框提示和点提示方面都表现出性能下降,特别是在点提示方面。作为基于CNN的模型,FastSAM显示出更为明显的下降,特别是在处理包含大量小目标的LVIS数据集时,这一点尤为明显。这种观察反映了CNN编码器在处理更复杂的分割场景时的局限性。相反,SAM-Lightening在分割性能方面与原始SAM最为匹配。即使在基于点的提示场景下,SAM-Lightening的mIoU也与原始SAM相似。下图为在提示模式下SAM- lightening和SAM的代表性图像分割结果。

2.3 在任意模式下的比较

        为了完整性和展示模型的多功能性,还在比较中包含了"segment-anything"模式的输出。从下图中展示的代表样本, SAM-Lightening和MobileSAM在分割结果上几乎与原始SAM无法区分。这种相似性在边缘清晰度和细节保留方面非常明显,这是高质量分割的特征。SAM-Lightening展示了其稳健性和准确性,与原始SAM的性能密切对齐。

3. 消融实验 

        许多先前的工作在SAM的输入尺寸上使用小于1024的尺寸。为了公平比较,我们也在这些场景下进行了实验,并发现保持FlashAttention的输入尺寸等于或小于512×512可以实现最佳性能。这表明FlashAttention的适用性取决于模型的输入尺寸和特定的硬件配置。决定使用FlashAttention应基于特定的应用背景和性能要求。
        虽然FlashAttention可以加速模型蒸馏的训练过程,但其对推理性能的影响取决于各种硬件指标。在我们的推理平台上,特别是对于输入尺寸为1024的SAM,多头注意力运算符表现出更多的计算密集型特征。如下图所示,这导致使用FlashAttention的推理速度略低于不使用FlashAttention的速度。因此,我们选择在蒸馏过程中使用FlashAttention来优化性能,而在评估阶段移除它。

  • 20
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值