This article mainly explains the MaskDecoder part of the SAM model.
Preface
It is strongly recommended to understand self-attention first.
Bilibili uploader 霹雳吧啦's blog post on attention
Bilibili uploader 霹雳吧啦's video on attention
1. MaskDecoder
Let's first get an overall look at the MaskDecoder in segment_anything/modeling/mask_decoder.py.
Its initialization is:
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
activation (nn.Module): the type of activation to use when upscaling masks
iou_head_depth (int): the depth of the MLP used to predict mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
# Two transposed convolutions, used for upsampling
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
# The upper MLP on the right side of the figure
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
# The lower MLP on the right side of the figure
self.iou_prediction_head = MLP(
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
Here self.transformer = transformer is the dark orange part on the left of the figure. self.iou_token and self.mask_tokens are not derived from any earlier output; they are learnable embeddings created from scratch, similar to the positional embeddings in ViT (if unfamiliar, see 霹雳吧啦's related videos and blog posts). self.output_upscaling is the pair of transposed convolutions, i.e. the 2x conv. trans. in the figure; each upsamples by a factor of 2, for 4x in total. self.iou_prediction_head is the MLP before the IoU output, and self.output_hypernetworks_mlps is the MLP above it. You might wonder where the token-to-image attn in the lower right is: it is not defined here, but is instead folded into the transformer on the left; more on that later.
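To see the 4x upsampling concretely, here is a minimal sketch (not the SAM source; the LayerNorm2d between the two transposed convolutions is omitted, and transformer_dim = 256 with a 64x64 image embedding is assumed, matching SAM's defaults):
import torch
import torch.nn as nn

transformer_dim = 256
# Two transposed convolutions with kernel_size=2 and stride=2, so each doubles H and W
upscaling = nn.Sequential(
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
    nn.GELU(),
    nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
    nn.GELU(),
)
src = torch.randn(1, transformer_dim, 64, 64)  # image embedding in BxCxHxW layout
print(upscaling(src).shape)  # torch.Size([1, 32, 256, 256]): 64x64 -> 256x256, channels 256 -> 32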
Its forward part is:
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
As you can see, the forward logic is mostly not written in forward itself but in the predict_masks function, which forward then calls. Focus on predict_masks. After tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1), the output_tokens + prompt_tokens step in the figure is done. Note that image_embeddings and image_pe are not passed into self.transformer directly; they are first converted into src and pos_src. Also note: do not treat the three arguments passed to self.transformer in hs, src = self.transformer(src, pos_src, tokens) as the q, k, v of an attention layer; this is explained later. The call returns two outputs, hs and src: hs has already gone through the token-to-image attn in the lower right of the figure, while src has not yet gone through the 2x conv. trans. above it. The rest of the process follows the figure: split the IoU part and the mask part out of hs, apply the MLPs, upsample src by 4x, do a matrix multiplication, and obtain the respective results.
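To make that final matrix multiplication concrete, here is a small sketch with made-up shapes (assuming transformer_dim = 256, num_mask_tokens = 4, and a 64x64 image embedding, as in SAM's default configuration); each mask token's 32-dim vector acts as per-pixel weights over the 32-channel upscaled embedding:
import torch

# Made-up shapes for illustration: batch 1, 4 mask tokens, 32 = transformer_dim // 8 channels
upscaled_embedding = torch.randn(1, 32, 256, 256)  # src after output_upscaling
hyper_in = torch.randn(1, 4, 32)                   # stacked outputs of output_hypernetworks_mlps

b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)  # torch.Size([1, 4, 256, 256]): one low-resolution mask per mask token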
2. TwoWayTransformer
We have now gone through all the details on the right side of the figure, but we have not yet looked at the dark part on the left, so let's do that now. When MaskDecoder is initialized, the line self.transformer = transformer defines the transformer used by the MaskDecoder class, but we do not yet know what this transformer actually is, so let's check where it is instantiated. In segment_anything/modeling/mask_decoder.py, go to the class MaskDecoder(nn.Module): line at the top, hold Ctrl and left-click MaskDecoder to see where it is used; this takes us into segment_anything/build_sam.py, where we can see the figure below:
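For reference, the relevant part of build_sam.py looks roughly like this (abridged; prompt_embed_dim is 256 there):
mask_decoder=MaskDecoder(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
    ),
    transformer_dim=prompt_embed_dim,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
),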
As you can see, the transformer used here is the TwoWayTransformer class, which corresponds to the dark part on the left of the figure (the x2 in the upper right corresponds to depth = 2, i.e. two stacked blocks). Strictly speaking the correspondence is not exact: as mentioned above, this TwoWayTransformer actually covers both the dark part on the left and the token-to-image attn in the lower right; more precisely, the dark part corresponds to the TwoWayAttentionBlock class.
Now hold Ctrl and left-click TwoWayTransformer to jump into it. This Python file defines three classes: TwoWayTransformer, TwoWayAttentionBlock, and Attention. Attention is the most basic attention and is not described further here; if it is unclear, refer to the two links in the preface.
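Still, as a quick reminder, here is a minimal sketch (not the SAM source) of the kind of attention used here, including the downsample_rate argument that shows up in the code below: q, k, and v are projected from embedding_dim down to embedding_dim // downsample_rate, multi-head scaled dot-product attention is computed, and the result is projected back to embedding_dim:
import math
import torch
import torch.nn as nn

class SimpleAttention(nn.Module):
    """Minimal sketch (not the SAM source): project q/k/v down by downsample_rate,
    run multi-head scaled dot-product attention, then project back up."""

    def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
        super().__init__()
        self.internal_dim = embedding_dim // downsample_rate  # e.g. 256 // 2 = 128
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        b, n_q, _ = q.shape
        n_k = k.shape[1]
        head_dim = self.internal_dim // self.num_heads
        # Project and split into heads: B x N x internal_dim -> B x num_heads x N x head_dim
        q = self.q_proj(q).view(b, n_q, self.num_heads, head_dim).transpose(1, 2)
        k = self.k_proj(k).view(b, n_k, self.num_heads, head_dim).transpose(1, 2)
        v = self.v_proj(v).view(b, n_k, self.num_heads, head_dim).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n_q, self.internal_dim)
        return self.out_proj(out)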
On top of Attention, the TwoWayAttentionBlock class is built, and on top of TwoWayAttentionBlock, the TwoWayTransformer class is built.
As before, let's start with the big picture, i.e. the TwoWayTransformer class.
First, look at its initialization:
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth): # build self.layers; depth is 2 at instantiation, i.e. two TwoWayAttentionBlock layers
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
# self.final_attn_token_to_image is the token-to-image attn in the lower right of the figure
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
As you can see, self.layers holds a number of TwoWayAttentionBlock modules; at instantiation depth is 2, so there are two of them, and this part corresponds to the dark orange part of the figure. self.final_attn_token_to_image is the token-to-image attn in the lower right.
Now let's look at its forward part:
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Run the dark orange part of the figure
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# After the loop above we have two results: queries is the output of the lower path, keys is the output of the upper path
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)  # the token-to-image attn in the lower right of the figure
queries = queries + attn_out
queries = self.norm_final_attn(queries)
# At this point queries has already passed through the token-to-image attn in the lower right, while keys has not yet gone through the two transposed convolutions
return queries, keys
The inputs image_embedding, image_pe, and point_embedding of this forward correspond respectively to src, pos_src, and tokens in the earlier call hs, src = self.transformer(src, pos_src, tokens). Seeing this, it should now be clear what was meant above by "hs has already gone through the token-to-image attn in the lower right, while src has not yet gone through the 2x conv. trans."
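If you want to see this correspondence with concrete shapes, a quick check like the following can help (the numbers here are assumptions: embedding_dim = 256, a 64x64 image embedding, and 7 tokens, i.e. 1 IoU token + 4 mask tokens + 2 prompt tokens):
import torch
from segment_anything.modeling.transformer import TwoWayTransformer  # module path as in the official repo

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
src = torch.randn(1, 256, 64, 64)      # corresponds to src in predict_masks
pos_src = torch.randn(1, 256, 64, 64)  # corresponds to pos_src
tokens = torch.randn(1, 7, 256)        # corresponds to tokens
hs, out = transformer(src, pos_src, tokens)
print(hs.shape)   # torch.Size([1, 7, 256]): split into iou_token_out and mask_tokens_out
print(out.shape)  # torch.Size([1, 4096, 256]): reshaped back to (1, 256, 64, 64) before upscaling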
Next, let's look at the TwoWayAttentionBlock used to build the TwoWayTransformer, starting with how it is called:
# This part is in def __init__
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
#......
# This part is in forward
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
queries = point_embedding
keys = image_embedding
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
We can see that the four inputs of TwoWayAttentionBlock's forward come from the three inputs of TwoWayTransformer's forward: queries comes from point_embedding, query_pe is point_embedding itself, keys comes from image_embedding, and key_pe comes from image_pe.
TwoWayAttentionBlock uses Attention (see the two links in the preface; not repeated here) to build a self-attention module and cross-attention modules. Its initialization code is as follows:
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
Here self.self_attn is the self-attention module, and the self.cross_... layers are the cross-attention modules.
Now look at the forward function of TwoWayAttentionBlock:
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
All q-related parameters can be regarded as inputs/outputs of the lower path, and all k-related parameters as inputs/outputs of the upper path; query_pe and key_pe are their respective positional encodings (if this is unfamiliar, see the two links in the preface and 霹雳吧啦's ViT explanation). The code matches the names of the parts in the figure: first a self-attention is applied to queries, i.e. self.self_attn, corresponding to the self attn in the dark orange part. Then self.cross_attn_token_to_image corresponds to the token to image attn, self.mlp to the mlp, and self.cross_attn_image_to_token to the image to token attn; the two cross-attention operations amount to doing attention twice with the roles of q and k swapped.
With that, TwoWayAttentionBlock is finished, returning two values, queries and keys, where queries corresponds to the output along the lower purple line and keys to the output along the upper green line. Note: the queries returned by TwoWayAttentionBlock has not yet gone through the token to image attn in the lower right; that step is done in the TwoWayTransformer class. Likewise, keys has not yet gone through the 2x conv. trans. in the figure; that step is done in the MaskDecoder class.
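As a tiny illustration of the q/k swap (not SAM code; SAM uses two separate Attention modules, and nn.MultiheadAttention is only a stand-in here with made-up shapes):
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
q = torch.randn(1, 7, 256)     # queries + query_pe (token side)
k = torch.randn(1, 4096, 256)  # keys + key_pe (image side)
queries = torch.randn(1, 7, 256)
keys = torch.randn(1, 4096, 256)

token_to_image, _ = attn(query=q, key=k, value=keys)     # (1, 7, 256): used to update queries
image_to_token, _ = attn(query=k, key=q, value=queries)  # (1, 4096, 256): used to update keys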
Summary
That's about it. If anything here is lacking or wrong, criticism and corrections are welcome.