Segment Anything 模型结构分析

SAM

首先先来讲一讲SAM。有讲的不对的地方请指出,谢谢!
在这里插入图片描述

SAM工作就是用新的prompt engineer+预训练大模型的范式来对图像进行分割,以实现zero-shot(旧的范式是pretrain+finetune)

整个SAM架构可以分成三个大部分,image encoder部分prompt encoder部分mask decoder部分。下面一一介绍。

Image Encoder

SAM的image encoder部分用的是MAE预训练的ViT,ViT这里就不做介绍了。原始图像被等比和padding的缩放到1024大小,采用kernel size为16,stride为16的卷积将图像离散化为64×64×768的向量,铺平后进入transformer encoder,输出的向量再通过两层的卷积压缩到embedding dimension为256。

image encoder这一部分的计算以及存储消耗是非常大的,在META官方的demo中,image embedding的计算也是在云端服务器中进行的。所以要实现模型轻量化,对这一部分需要做改进。

prompt encoder

在META官方的demo中,可以通过给定一个点位(point)来进行语义分割,如下图所示。
在这里插入图片描述

也可以框选一个区域,来进行语义分割,如下图所示。
在这里插入图片描述

此外,论文中还提到text prompt。这个功能在demo中没有展现,个人理解就是给一个我想要分割的区域的描述,SAM根据描述进行相应区域的分割。

上面说到的三种prompts在论文中归类为稀疏类prompt(sparse prompt)。point和box(左上角的点&右下角的点)采用position embedding(transformer里的东西,是一种用sines和cosines组成的编码,能够表示一个东西的相对位置和顺序关系)+learnable cls embeddings作为embedding;(这个部分可以看一下代码

text prompt同样也是稀疏类prompt,但显然不能用pe来表示它。SAM中对应于text的encoder是CLIP架构中的text encoder,具体可以看CLIP的相关内容。

还有一个prompt是mask,采用卷积神经网络进行下采样后和image embedding进行element-wise相加(使得,就是1+1=2的加,反正都挺玄学的)

mask decoder

下图是论文中给出的mask decoder的结构
在这里插入图片描述
相信大部分人和我一样,乍一眼看,一脸懵逼,这么多箭头,而且论文中对它的描述也很少。那我们从左往右来分析。

image embedding和prompt embedding就是上面提到的prompt部分的内容。而output tokens前面并没有提到,其实看过ViT的同学应该对这个玩意儿不陌生,VIT做的是分类任务,在image embedding的最前面加了一个cls token,在好几层的self attention之后,输出的这个cls token就是对应的目标类别。这里也是同理,SAM做的是语义分割任务,但是输出不止一个mask,如下图所示。
在这里插入图片描述
这个应该是针对于point prompt来说的,拿论文中这个剪刀举例。我point点在剪刀柄上,我想要分割的区域可能会是上面三种的其中一种,也就是“全部”、“部分”、“子部分”。那么根据什么来展示出最后的输出呢,就涉及到这个output tokens,一个output mask对应一个output tokens,还有一个IoU prediction head来选择三个mask中它认为最好的输出(这个IoU prediction head是模型中的一个learnable分支,在训练模型时根据GT来训练)。

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor
### Segment Anything 模型训练方法 #### 数据准备 为了有效地训练Segment Anything (SAM)模型,数据集的选择至关重要。此模型采用了一种特殊的预训练策略,在训练期间为每个样本模拟多种提示,并将预测结果与真实的标签进行比较[^1]。 #### 预训练策略 在预训练阶段,模型通过模仿不同的用户输入方式来学习如何响应各种类型的提示。这不仅限于精确的位置点击或是边界框定义,还包括更抽象的概念指导。这样的设计使得即使面对模糊不清的任务描述,模型也能够提供合理的分割掩码输出。 #### 架构调整 当涉及到具体的实现细节时,可以根据需求定制化修改原始的SAM架构或创建类似的新型结构。对于那些希望探索新领域应用的研究人员来说,理解并实验这两种主要类别下的具体差异是非常有益的[^4]。 #### 实践指南 实际操作上,建议访问「OpenBayes」平台上的官方文档和教程资源,这里提供了详细的源代码解释以及在线推理环境设置说明,帮助开发者快速入门并掌握使用技巧[^3]。 ```python import torch from segment_anything import sam_model_registry, SamPredictor device = "cuda" if torch.cuda.is_available() else "cpu" model_type = "vit_b" sam_checkpoint = "./checkpoint/sam_vit_b_01ec64.pth" model = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device) predictor = SamPredictor(model) image_path = 'path_to_your_image.jpg' # 加载图片并执行其他必要的初始化步骤... ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值