SAM
首先先来讲一讲SAM。有讲的不对的地方请指出,谢谢!
SAM工作就是用新的prompt engineer+预训练大模型的范式来对图像进行分割,以实现zero-shot(旧的范式是pretrain+finetune)
整个SAM架构可以分成三个大部分,image encoder部分、prompt encoder部分、mask decoder部分。下面一一介绍。
Image Encoder
SAM的image encoder部分用的是MAE预训练的ViT,ViT这里就不做介绍了。原始图像被等比和padding的缩放到1024大小,采用kernel size为16,stride为16的卷积将图像离散化为64×64×768的向量,铺平后进入transformer encoder,输出的向量再通过两层的卷积压缩到embedding dimension为256。
image encoder这一部分的计算以及存储消耗是非常大的,在META官方的demo中,image embedding的计算也是在云端服务器中进行的。所以要实现模型轻量化,对这一部分需要做改进。
prompt encoder
在META官方的demo中,可以通过给定一个点位(point)来进行语义分割,如下图所示。
也可以框选一个区域,来进行语义分割,如下图所示。
此外,论文中还提到text prompt。这个功能在demo中没有展现,个人理解就是给一个我想要分割的区域的描述,SAM根据描述进行相应区域的分割。
上面说到的三种prompts在论文中归类为稀疏类prompt(sparse prompt)。point和box(左上角的点&右下角的点)采用position embedding(transformer里的东西,是一种用sines和cosines组成的编码,能够表示一个东西的相对位置和顺序关系)+learnable cls embeddings作为embedding;(这个部分可以看一下代码)
text prompt同样也是稀疏类prompt,但显然不能用pe来表示它。SAM中对应于text的encoder是CLIP架构中的text encoder,具体可以看CLIP的相关内容。
还有一个prompt是mask,采用卷积神经网络进行下采样后和image embedding进行element-wise相加(使得,就是1+1=2的加,反正都挺玄学的)
mask decoder
下图是论文中给出的mask decoder的结构
相信大部分人和我一样,乍一眼看,一脸懵逼,这么多箭头,而且论文中对它的描述也很少。那我们从左往右来分析。
image embedding和prompt embedding就是上面提到的prompt部分的内容。而output tokens前面并没有提到,其实看过ViT的同学应该对这个玩意儿不陌生,VIT做的是分类任务,在image embedding的最前面加了一个cls token,在好几层的self attention之后,输出的这个cls token就是对应的目标类别。这里也是同理,SAM做的是语义分割任务,但是输出不止一个mask,如下图所示。
这个应该是针对于point prompt来说的,拿论文中这个剪刀举例。我point点在剪刀柄上,我想要分割的区域可能会是上面三种的其中一种,也就是“全部”、“部分”、“子部分”。那么根据什么来展示出最后的输出呢,就涉及到这个output tokens,一个output mask对应一个output tokens,还有一个IoU prediction head来选择三个mask中它认为最好的输出(这个IoU prediction head是模型中的一个learnable分支,在训练模型时根据GT来训练)。
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
# output_tokens由IoU_token和mask_token组成,将他们进行拼接
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
#将原来2维的output_tokens空间化后按照sparse_prompt_embedding的维度进行扩展(sparse_prompt_embedding的维度是什么呢?)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
#将扩展后的output_tokens和sparse_prompt_embeddings进行拼接
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
# 我的理解就是有几个tokens就把image_embedding扩充几次,这个tokens应该和prompt的类型有关
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
# 把image_embeddings和mask_embeddings加上
src = src + dense_prompt_embeddings
# 和ViT一样,加上learnable_position_encoding
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
# token to image attention(先按下不表)
#这里的transformer整合了右下角的token to image attn.后面再说
hs, src = self.transformer(src, pos_src, tokens)
# 从hs中分离出两个token
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
# 梯形的那个上采样
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
那么上面先按下不表的内容我们下面就可以拿出来说了,在上面的代码中的transformer并不只是正常框架的transformer,里面包含了mask decoder中左边深色框住部分以及右下角的token to image attn部分,所以上面把image embedding、output token、prompt token处理完了之后,一个transformer出来就直接给2xconv和分output_token了。
在论文中提到,mask decoder用了TwoWayTransformer。在代码中,我们可以在segment-anything/segment_anything/build_sam.py路径下找到MaskDecoder类的相关代码,如下图所示。
可以看到,transformer调用的是TwoWayTransfromer类,depth为2就是结构示意图中的×2。
根据segment-anything/segment_anything/modeling/transformer.py路径可以找到TwoWayTransformer的定义。相同目录下也能找到TwoWayAttentionBlock,上面也提到了,transformer包括深色框选部分以及右下角的token to image attn.,那么这个TwoWayAttentionBlock就是pure的深色框选部分的代码。
先来看整体的TwoWayTransformer的代码
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
# 这两个注释之间的内容就是实现框架的结构,depth为2,start
for i in range(depth):
self.layers.append(
# 深色框选部分
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
# 显然这个就是右下角的token to image attn.
self.final_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
# 到此能说明结构走向 end
def forward(
# 这部分就是前面两个encoder得到的结果,对于point prompt和box prompt来说其实都是点,无非前者是一个点,后者是两个点
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
# 显然,image embedding是四维的,这里是将第三维开始展开,将后面的维度转化为一维
# 那么具体到image embedding,第三维和第四维也就是high和width
# 然后把三个维从BxCxHxW调整为BxHWxC
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding # 为啥要加两回?
k = keys + image_pe # 合理的,image_embedding直接加上position_encoding
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries) # 这里的norm_final_attn就是一个Layernorm
# 这里还要再看一下
return queries, keys
好,看完整个的结构,再来单独看深色框选部分的代码,也就是TwoWayAttentionBlock
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
到这里,SAM的模型结构差不多就结束了,但是离理解SAM还差一个核心,也就是数据的设计,这在论文中也花了非常大的篇幅进行介绍,也是理解SAM非常重要的一环,应为模型中的很多设计都和数据的设计有关。这里先挖坑(20240327),希望有时间能来填上。