SAM
首先先来讲一讲SAM。有讲的不对的地方请指出,谢谢!
SAM工作就是用新的prompt engineer+预训练大模型的范式来对图像进行分割,以实现zero-shot(旧的范式是pretrain+finetune)
整个SAM架构可以分成三个大部分,image encoder部分、prompt encoder部分、mask decoder部分。下面一一介绍。
Image Encoder
SAM的image encoder部分用的是MAE预训练的ViT,ViT这里就不做介绍了。原始图像被等比和padding的缩放到1024大小,采用kernel size为16,stride为16的卷积将图像离散化为64×64×768的向量,铺平后进入transformer encoder,输出的向量再通过两层的卷积压缩到embedding dimension为256。
image encoder这一部分的计算以及存储消耗是非常大的,在META官方的demo中,image embedding的计算也是在云端服务器中进行的。所以要实现模型轻量化,对这一部分需要做改进。
prompt encoder
在META官方的demo中,可以通过给定一个点位(point)来进行语义分割,如下图所示。
也可以框选一个区域,来进行语义分割,如下图所示。
此外,论文中还提到text prompt。这个功能在demo中没有展现,个人理解就是给一个我想要分割的区域的描述,SAM根据描述进行相应区域的分割。
上面说到的三种prompts在论文中归类为稀疏类prompt(sparse prompt)。point和box(左上角的点&右下角的点)采用position embedding(transformer里的东西,是一种用sines和cosines组成的编码,能够表示一个东西的相对位置和顺序关系)+learnable cls embeddings作为embedding;(这个部分可以看一下代码)
text prompt同样也是稀疏类prompt,但显然不能用pe来表示它。SAM中对应于text的encoder是CLIP架构中的text encoder,具体可以看CLIP的相关内容。
还有一个prompt是mask,采用卷积神经网络进行下采样后和image embedding进行element-wise相加(使得,就是1+1=2的加,反正都挺玄学的)
mask decoder
下图是论文中给出的mask decoder的结构
相信大部分人和我一样,乍一眼看,一脸懵逼,这么多箭头,而且论文中对它的描述也很少。那我们从左往右来分析。
image embedding和prompt embedding就是上面提到的prompt部分的内容。而output tokens前面并没有提到,其实看过ViT的同学应该对这个玩意儿不陌生,VIT做的是分类任务,在image embedding的最前面加了一个cls token,在好几层的self attention之后,输出的这个cls token就是对应的目标类别。这里也是同理,SAM做的是语义分割任务,但是输出不止一个mask,如下图所示。
这个应该是针对于point prompt来说的,拿论文中这个剪刀举例。我point点在剪刀柄上,我想要分割的区域可能会是上面三种的其中一种,也就是“全部”、“部分”、“子部分”。那么根据什么来展示出最后的输出呢,就涉及到这个output tokens,一个output mask对应一个output tokens,还有一个IoU prediction head来选择三个mask中它认为最好的输出(这个IoU prediction head是模型中的一个learnable分支,在训练模型时根据GT来训练)。
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor