The Mask Decoder in the SAM Model

This article walks through the MaskDecoder part of the SAM model.


Preface

It is strongly recommended to first get a clear understanding of self-attention:
霹雳吧啦's blog post on attention
霹雳吧啦's video on attention

1. MaskDecoder

Let's first take an overall look at the MaskDecoder in segment_anything/modeling/mask_decoder.py.
Its initializer is:

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict when disambiguating masks
          activation (nn.Module): the type of activation to use when upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
		
        # Two transposed convolutions, used for upsampling
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        # The upper MLP on the right side of the figure
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        # The lower MLP on the right side of the figure
        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

Here, self.transformer = transformer is the dark-orange part on the left of the figure. self.iou_token and self.mask_tokens are not derived from any earlier output; they are learnable embeddings created out of thin air, similar to the positional embedding in ViT (if that is unfamiliar, see 霹雳吧啦's related videos and blog posts). self.output_upscaling is the pair of transposed convolutions, i.e. the "2x conv trans" in the figure; each doubles the spatial resolution, for 4x in total. self.iou_prediction_head is the MLP right before the IoU output, and self.output_hypernetworks_mlps is the MLP above it. You may be wondering where the "token to image attn" at the lower right is: it is not defined here, but is folded into the transformer block on the left; more on that later.
[Figure: MaskDecoder architecture diagram]
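To make the upscaling concrete, here is a quick shape check of the two transposed convolutions, assuming the default SAM configuration (transformer_dim=256 and a 64x64 image embedding); the LayerNorm2d between them is omitted, so treat this as an illustrative sketch rather than the repository code. Each ConvTranspose2d with kernel_size=2, stride=2 doubles the spatial resolution, so 64x64 becomes 256x256 while the channels drop from 256 to 32.

    import torch
    import torch.nn as nn

    transformer_dim = 256  # default SAM value (assumed for this sketch)
    upscaling = nn.Sequential(
        nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
        nn.GELU(),
        nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        nn.GELU(),
    )
    x = torch.randn(1, transformer_dim, 64, 64)  # image embedding: B x 256 x 64 x 64
    print(upscaling(x).shape)                    # torch.Size([1, 32, 256, 256]) -> 4x larger, 32 channels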

The forward part is:

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred

Notice that the forward logic is not written in forward itself; it lives in predict_masks, which forward simply calls. Focus on predict_masks: after tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1), the "output_tokens + prompt_tokens" step in the figure is done. image_embeddings and image_pe are not passed into self.transformer directly; they are first turned into src and pos_src. Note: do not mistake the three arguments in hs, src = self.transformer(src, pos_src, tokens) for the q, k, v of attention; this is explained later. Also note that the call returns two outputs, hs and src: hs has already been through the "token to image attn" at the lower right of the figure, while src has not yet been through the "2x conv trans" at the top. From here on the code follows the figure: the IoU part and the mask part are split out of hs, the MLPs are applied, src is upsampled 4x, and a matrix multiplication produces the respective results.
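To see why the final matrix multiplication yields one mask per mask token, here is a toy re-enactment of the last step of predict_masks, with shapes assumed from the default configuration (transformer_dim=256, so 32 channels after upscaling, and 3 + 1 = 4 mask tokens); the tensors are random and purely illustrative.

    import torch

    b, num_mask_tokens = 1, 4           # 3 multimask outputs + 1
    c, h, w = 32, 256, 256              # transformer_dim // 8 channels, 4x-upscaled 64x64 embedding

    hyper_in = torch.randn(b, num_mask_tokens, c)  # one weight vector per mask token (from the MLPs)
    upscaled_embedding = torch.randn(b, c, h, w)   # output of self.output_upscaling

    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
    print(masks.shape)  # torch.Size([1, 4, 256, 256]) -> one low-resolution mask per mask token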

2. TwoWayTransformer

We have now gone through all the details on the right side of the figure, but not the dark part on the left, so let's turn to that. When MaskDecoder is initialized, self.transformer = transformer defines the transformer used inside the class, but we do not yet know what that transformer actually is, so let's look at where it is instantiated. In segment_anything/modeling/mask_decoder.py, hold Ctrl and left-click MaskDecoder at the top (class MaskDecoder(nn.Module):) to see where it is used. This takes us to segment_anything/build_sam.py, where we see the following:
[Figure: MaskDecoder instantiation in segment_anything/build_sam.py]

The transformer used here is the TwoWayTransformer class, which corresponds to the dark part on the left of the figure below (the "×2" in the upper-right corner means the block is stacked twice, i.e. depth=2). Strictly speaking the correspondence is not exact: as mentioned above, TwoWayTransformer actually contains both the dark part on the left and the "token to image attn" at the lower right; more precisely, the dark part corresponds to the TwoWayAttentionBlock class.
[Figure: mask decoder architecture diagram, with the TwoWayTransformer/TwoWayAttentionBlock region highlighted]
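For reference, the instantiation in build_sam.py looks roughly like this (the argument values follow the official repository, where prompt_embed_dim is 256, but verify against your local copy):

    from segment_anything.modeling import MaskDecoder, TwoWayTransformer

    prompt_embed_dim = 256  # value used in the official build_sam.py

    mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=prompt_embed_dim,
            mlp_dim=2048,
            num_heads=8,
        ),
        transformer_dim=prompt_embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )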

Now hold Ctrl and left-click TwoWayTransformer to jump into it. This Python file defines three classes: TwoWayTransformer, TwoWayAttentionBlock, and Attention. Attention is the most basic attention module and is not described further here; if it is unfamiliar, see the two links in the preface.
TwoWayAttentionBlock is built on top of Attention, and TwoWayTransformer is in turn built on top of TwoWayAttentionBlock.
As before, let's start with the big picture, i.e. the TwoWayTransformer class.
First, its initializer:

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):  # build self.layers; depth is 2 at instantiation, i.e. two TwoWayAttentionBlocks
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )
        # self.final_attn_token_to_image is the "token to image attn" at the lower right of the figure
        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

As you can see, self.layers holds a stack of TwoWayAttentionBlocks; depth is 2 at instantiation, so there are two of them, and together they form the dark-orange part of the figure. self.final_attn_token_to_image is the "token to image attn" at the lower right.

Now let's look at its forward part:

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Run the dark-orange blocks in the figure
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )
        # After the loop we have two results: queries is the bottom-path output, keys is the top-path output
        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)  # the "token to image attn" at the lower right of the figure
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)
        # queries has now been through the final "token to image attn", while keys has not yet been through the two transposed convolutions
        return queries, keys

The inputs image_embedding, image_pe, and point_embedding of this forward correspond to src, pos_src, and tokens in the earlier call hs, src = self.transformer(src, pos_src, tokens). Seeing this, it should now be clear what was meant above by "hs has already been through the token to image attn at the lower right, while src has not yet been through the 2x conv trans".
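As a quick shape check of this forward (assuming segment_anything is importable and the default configuration: embedding_dim=256, a 64x64 image embedding, and an assumed 7 tokens), something like the following illustrative sketch can be run:

    import torch
    from segment_anything.modeling import TwoWayTransformer

    transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
    image_embedding = torch.randn(1, 256, 64, 64)  # src:     B x C x H x W
    image_pe = torch.randn(1, 256, 64, 64)         # pos_src: same shape as src
    tokens = torch.randn(1, 7, 256)                # tokens:  iou token + mask tokens + sparse prompts

    hs, src = transformer(image_embedding, image_pe, tokens)
    print(hs.shape)   # torch.Size([1, 7, 256])    -> queries, already through the final token-to-image attn
    print(src.shape)  # torch.Size([1, 4096, 256]) -> keys, flattened 64x64 image features, not yet upscaled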

Next let's look at the TwoWayAttentionBlock used to build TwoWayTransformer, starting with how it is called:

        # This part is in __init__
        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )
            
        #......
        
        # This part is in forward
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)
        queries = point_embedding
        keys = image_embedding
        
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

We can see that the four inputs of TwoWayAttentionBlock's forward come from the three inputs of TwoWayTransformer's forward: queries comes from point_embedding, query_pe is point_embedding itself, keys comes from (the flattened) image_embedding, and key_pe comes from image_pe.
TwoWayAttentionBlock uses Attention (again, see the two links in the preface; not repeated here) to build one self-attention module and two cross-attention modules. Its initialization code is as follows:

        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)

        self.skip_first_layer_pe = skip_first_layer_pe

Here self.self_attn is the self-attention module, and the two self.cross_attn_... modules are the cross-attention modules.
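The attention_downsample_rate argument used for the cross-attention modules shrinks the inner attention dimension: the q/k/v projections map embedding_dim down to embedding_dim // downsample_rate before the usual multi-head scaled dot-product attention, and an output projection maps back up. The sketch below mirrors that idea; it is a simplified illustration, not the repository's Attention class.

    import math
    import torch
    import torch.nn as nn

    class DownsampledAttention(nn.Module):
        """Multi-head attention whose inner dimension is embedding_dim // downsample_rate."""

        def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
            super().__init__()
            self.internal_dim = embedding_dim // downsample_rate
            self.num_heads = num_heads
            self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
            self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
            self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
            self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
            b, n_q, _ = q.shape
            n_k = k.shape[1]
            head_dim = self.internal_dim // self.num_heads
            # Project to the smaller internal dim and split into heads: B x heads x N x head_dim
            q = self.q_proj(q).view(b, n_q, self.num_heads, head_dim).transpose(1, 2)
            k = self.k_proj(k).view(b, n_k, self.num_heads, head_dim).transpose(1, 2)
            v = self.v_proj(v).view(b, n_k, self.num_heads, head_dim).transpose(1, 2)
            # Standard scaled dot-product attention
            attn = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
            out = attn.softmax(dim=-1) @ v
            # Merge heads and project back up to embedding_dim
            out = out.transpose(1, 2).reshape(b, n_q, self.internal_dim)
            return self.out_proj(out)

    # Example: tokens attending to a flattened 64x64 image, with downsample_rate=2
    attn = DownsampledAttention(embedding_dim=256, num_heads=8, downsample_rate=2)
    tokens, image = torch.randn(1, 7, 256), torch.randn(1, 4096, 256)
    print(attn(q=tokens, k=image, v=image).shape)  # torch.Size([1, 7, 256])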

Now look at the forward function of TwoWayAttentionBlock:

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys

[Figure: TwoWayAttentionBlock structure, bottom (token) path and top (image) path]
All q-related parameters can be viewed as the inputs/outputs of the bottom path, and the k-related parameters as the inputs/outputs of the top path; query_pe and key_pe are the respective positional encodings (if unfamiliar, see the two links in the preface as well as 霹雳吧啦's ViT walkthrough). The names in the code match the parts of the figure: first a self-attention on queries, self.self_attn, which corresponds to "self attn" in the dark-orange block; then self.cross_attn_token_to_image corresponds to "token to image attn", self.mlp to "mlp", and self.cross_attn_image_to_token to "image to token attn". The two cross-attention operations are essentially attention applied twice with q and k swapped. That completes TwoWayAttentionBlock: it returns two values, queries and keys, where queries corresponds to the output of the purple line at the bottom and keys to the output of the green line at the top. Note: the queries output by TwoWayAttentionBlock has not yet been through the "token to image attn" at the lower right; that step happens in the TwoWayTransformer class. Likewise keys has not been through the "2x conv trans" in the figure; that step happens in the MaskDecoder class.
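As a sanity check on the shapes, the block can be exercised in isolation. Assuming segment_anything is importable (the import path follows the repository layout) and using made-up token counts, an illustrative sketch:

    import torch
    from segment_anything.modeling.transformer import TwoWayAttentionBlock

    block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
    queries = torch.randn(1, 7, 256)      # e.g. 1 iou token + 4 mask tokens + 2 prompt tokens
    keys = torch.randn(1, 64 * 64, 256)   # flattened 64x64 image embedding
    q_pe = torch.randn(1, 7, 256)         # positional encoding for the tokens
    k_pe = torch.randn(1, 64 * 64, 256)   # positional encoding for the image

    q_out, k_out = block(queries=queries, keys=keys, query_pe=q_pe, key_pe=k_pe)
    print(q_out.shape, k_out.shape)  # torch.Size([1, 7, 256]) torch.Size([1, 4096, 256])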

Summary

That about covers it. If anything is lacking or incorrect, criticism and corrections are welcome.
