Segment Anything 模型结构分析

SAM

首先先来讲一讲SAM。有讲的不对的地方请指出,谢谢!
在这里插入图片描述

SAM工作就是用新的prompt engineer+预训练大模型的范式来对图像进行分割,以实现zero-shot(旧的范式是pretrain+finetune)

整个SAM架构可以分成三个大部分,image encoder部分prompt encoder部分mask decoder部分。下面一一介绍。

Image Encoder

SAM的image encoder部分用的是MAE预训练的ViT,ViT这里就不做介绍了。原始图像被等比和padding的缩放到1024大小,采用kernel size为16,stride为16的卷积将图像离散化为64×64×768的向量,铺平后进入transformer encoder,输出的向量再通过两层的卷积压缩到embedding dimension为256。

image encoder这一部分的计算以及存储消耗是非常大的,在META官方的demo中,image embedding的计算也是在云端服务器中进行的。所以要实现模型轻量化,对这一部分需要做改进。

prompt encoder

在META官方的demo中,可以通过给定一个点位(point)来进行语义分割,如下图所示。
在这里插入图片描述

也可以框选一个区域,来进行语义分割,如下图所示。
在这里插入图片描述

此外,论文中还提到text prompt。这个功能在demo中没有展现,个人理解就是给一个我想要分割的区域的描述,SAM根据描述进行相应区域的分割。

上面说到的三种prompts在论文中归类为稀疏类prompt(sparse prompt)。point和box(左上角的点&右下角的点)采用position embedding(transformer里的东西,是一种用sines和cosines组成的编码,能够表示一个东西的相对位置和顺序关系)+learnable cls embeddings作为embedding;(这个部分可以看一下代码

text prompt同样也是稀疏类prompt,但显然不能用pe来表示它。SAM中对应于text的encoder是CLIP架构中的text encoder,具体可以看CLIP的相关内容。

还有一个prompt是mask,采用卷积神经网络进行下采样后和image embedding进行element-wise相加(使得,就是1+1=2的加,反正都挺玄学的)

mask decoder

下图是论文中给出的mask decoder的结构
在这里插入图片描述
相信大部分人和我一样,乍一眼看,一脸懵逼,这么多箭头,而且论文中对它的描述也很少。那我们从左往右来分析。

image embedding和prompt embedding就是上面提到的prompt部分的内容。而output tokens前面并没有提到,其实看过ViT的同学应该对这个玩意儿不陌生,VIT做的是分类任务,在image embedding的最前面加了一个cls token,在好几层的self attention之后,输出的这个cls token就是对应的目标类别。这里也是同理,SAM做的是语义分割任务,但是输出不止一个mask,如下图所示。
在这里插入图片描述
这个应该是针对于point prompt来说的,拿论文中这个剪刀举例。我point点在剪刀柄上,我想要分割的区域可能会是上面三种的其中一种,也就是“全部”、“部分”、“子部分”。那么根据什么来展示出最后的输出呢,就涉及到这个output tokens,一个output mask对应一个output tokens,还有一个IoU prediction head来选择三个mask中它认为最好的输出(这个IoU prediction head是模型中的一个learnable分支,在训练模型时根据GT来训练)。

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        # output_tokens由IoU_token和mask_token组成,将他们进行拼接
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        #将原来2维的output_tokens空间化后按照sparse_prompt_embedding的维度进行扩展(sparse_prompt_embedding的维度是什么呢?)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        #将扩展后的output_tokens和sparse_prompt_embeddings进行拼接
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        # 我的理解就是有几个tokens就把image_embedding扩充几次,这个tokens应该和prompt的类型有关
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        # 把image_embeddings和mask_embeddings加上
        src = src + dense_prompt_embeddings
        # 和ViT一样,加上learnable_position_encoding
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        # token to image attention(先按下不表)
        #这里的transformer整合了右下角的token to image attn.后面再说
        hs, src = self.transformer(src, pos_src, tokens) 
        # 从hs中分离出两个token
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        # 梯形的那个上采样
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred

那么上面先按下不表的内容我们下面就可以拿出来说了,在上面的代码中的transformer并不只是正常框架的transformer,里面包含了mask decoder中左边深色框住部分以及右下角的token to image attn部分,所以上面把image embedding、output token、prompt token处理完了之后,一个transformer出来就直接给2xconv和分output_token了。

在论文中提到,mask decoder用了TwoWayTransformer。在代码中,我们可以在segment-anything/segment_anything/build_sam.py路径下找到MaskDecoder类的相关代码,如下图所示。
在这里插入图片描述

可以看到,transformer调用的是TwoWayTransfromer类,depth为2就是结构示意图中的×2。
在这里插入图片描述

根据segment-anything/segment_anything/modeling/transformer.py路径可以找到TwoWayTransformer的定义。相同目录下也能找到TwoWayAttentionBlock,上面也提到了,transformer包括深色框选部分以及右下角的token to image attn.,那么这个TwoWayAttentionBlock就是pure的深色框选部分的代码。

先来看整体的TwoWayTransformer的代码

class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()
        # 这两个注释之间的内容就是实现框架的结构,depth为2,start
        for i in range(depth):
            self.layers.append(
                # 深色框选部分
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )
        # 显然这个就是右下角的token to image attn.
        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)
        # 到此能说明结构走向 end
    def forward(
        # 这部分就是前面两个encoder得到的结果,对于point prompt和box prompt来说其实都是点,无非前者是一个点,后者是两个点
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        # 显然,image embedding是四维的,这里是将第三维开始展开,将后面的维度转化为一维
        # 那么具体到image embedding,第三维和第四维也就是high和width
        # 然后把三个维从BxCxHxW调整为BxHWxC
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding # 为啥要加两回?
        k = keys + image_pe # 合理的,image_embedding直接加上position_encoding
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)  # 这里的norm_final_attn就是一个Layernorm
        # 这里还要再看一下
        return queries, keys

好,看完整个的结构,再来单独看深色框选部分的代码,也就是TwoWayAttentionBlock

class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys

到这里,SAM的模型结构差不多就结束了,但是离理解SAM还差一个核心,也就是数据的设计,这在论文中也花了非常大的篇幅进行介绍,也是理解SAM非常重要的一环,应为模型中的很多设计都和数据的设计有关。这里先挖坑(20240327),希望有时间能来填上。

  • 27
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值