Overview
A while ago I read Facebook's large segmentation model SAM. The paper [1] is fairly abstract, so I went through the code to understand the model's details. My notes are recorded below; where my understanding falls short, discussion is welcome.
Model architecture
The model consists of three parts: an image encoder network, a prompt network, and a decoder network.
Below, I use the official code to briefly analyze how each module works.
Backbone network
The forward pass of the image encoding network looks like this:
x = self.patch_embed(x)
if self.pos_embed is not None:
    x = x + self.pos_embed
for blk in self.blocks:
    x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
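As a side note on the first line: patch_embed in the repo is essentially a strided convolution that cuts the image into a grid of patch tokens stored channels-last, which is also why the neck call above permutes back to (B, C, H, W). A minimal sketch of such a layer (the kernel/stride/dim values are the ViT-B defaults as I understand them, so treat them as assumptions):

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    # Image-to-patch embedding via a strided conv (sketch, not the verbatim repo code)
    def __init__(self, kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)              # (B, embed_dim, H/16, W/16)
        return x.permute(0, 2, 3, 1)  # channels-last: (B, H/16, W/16, embed_dim)

# e.g. a 1024x1024 input becomes a 64x64 grid of 768-dim patch tokens
tokens = PatchEmbed()(torch.randn(1, 3, 1024, 1024))  # shape (1, 64, 64, 768)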
From the forward pass above, the backbone is patch embedding + positional encoding + a stack of attention blocks, i.e. a classic transformer structure.
It is followed by a neck that changes the dimension of the output embedding.
The neck is a sequence of convolutions each followed by layer normalization, a common pattern in image models:
self.neck = nn.Sequential(
    nn.Conv2d(
        embed_dim,
        out_chans,
        kernel_size=1,
        bias=False,
    ),
    LayerNorm2d(out_chans),
    nn.Conv2d(
        out_chans,
        out_chans,
        kernel_size=3,
        padding=1,
        bias=False,
    ),
    LayerNorm2d(out_chans),
)
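Note that LayerNorm2d here is not torch.nn.LayerNorm but a small helper that normalizes over the channel dimension of a (B, C, H, W) tensor. It looks roughly like the sketch below (reconstructed from memory of the repo, so the details are best-effort):

import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    # Channel-wise layer norm for (B, C, H, W) feature maps (best-effort sketch)
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)               # mean over channels at every pixel
        s = (x - u).pow(2).mean(1, keepdim=True)  # variance over channels at every pixel
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]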
Prompt network
The prompt module encodes the prompt inputs. From the code, it accepts two kinds of prompts: dense (masks) and sparse (positive/negative click points and the two corner points of a box). The module encodes each kind separately: sparse prompts get a simple positional encoding, while masks additionally pass through a few convolution layers. A usage sketch follows the code below.
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Embeds different types of prompts, returning both sparse and dense
    embeddings.

    Arguments:
      points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
        and labels to embed.
      boxes (torch.Tensor or none): boxes to embed
      masks (torch.Tensor or none): masks to embed

    Returns:
      torch.Tensor: sparse embeddings for the points and boxes, with shape
        BxNx(embed_dim), where N is determined by the number of input points
        and boxes.
      torch.Tensor: dense embeddings for the masks, in the shape
        Bx(embed_dim)x(embed_H)x(embed_W)
    """
    bs = self._get_batch_size(points, boxes, masks)
    sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
    if points is not None:
        coords, labels = points
        point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
        sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
    if boxes is not None:
        box_embeddings = self._embed_boxes(boxes)
        sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
    if masks is not None:
        dense_embeddings = self._embed_masks(masks)
    else:
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        )
    return sparse_embeddings, dense_embeddings
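To make the interface concrete, here is a hypothetical call for a single positive click. The shapes and the embed_dim=256 / 64x64 grid are my assumptions about the default configuration, and prompt_encoder is assumed to be an already-constructed PromptEncoder:

import torch

point_coords = torch.tensor([[[512.0, 512.0]]])  # (B=1, N=1, 2) pixel coordinates in the 1024x1024 input
point_labels = torch.tensor([[1]])               # 1 = positive click, 0 = negative click

sparse, dense = prompt_encoder(
    points=(point_coords, point_labels),
    boxes=None,
    masks=None,
)
# sparse: (1, N+1, 256) -- since boxes is None, a padding point is appended after the click
# dense:  (1, 256, 64, 64) -- the learned "no mask" embedding broadcast over the image grid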
While reading the code, I did not find an entry point for text prompts; the paper also only describes text prompting as a preliminary exploration, so presumably that part of the code and model was not released.
Decoding and output
def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predicts masks. See 'forward' for more details."""
    # Concatenate output tokens
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = src + dense_prompt_embeddings
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # Run the transformer
    hs, src = self.transformer(src, pos_src, tokens)
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    hyper_in_list: List[torch.Tensor] = []
    for i in range(self.num_mask_tokens):
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding.shape
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)
    return masks, iou_pred
As the code shows, this module receives the embeddings from the two previous modules and behaves much like the decoder of a segmentation model: the masks come out after upsampling through output_upscaling. The difference is that a transformer first mixes the prompt embeddings with the image encoding.
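The last two steps are worth isolating: each mask token goes through its own small MLP (a "hypernetwork" producing a per-mask filter), and the mask is simply the dot product of that filter with the upscaled embedding at every spatial position. A minimal illustration with made-up shapes:

import torch

num_masks, channels, h, w = 3, 32, 256, 256
hyper_in = torch.randn(1, num_masks, channels)  # one filter per mask token (output of the per-token MLPs)
upscaled = torch.randn(1, channels, h, w)       # upscaled image embedding

# Dot product at every pixel: (1, num_masks, C) @ (1, C, H*W) -> (1, num_masks, H*W)
masks = (hyper_in @ upscaled.view(1, channels, h * w)).view(1, num_masks, h, w)
print(masks.shape)  # torch.Size([1, 3, 256, 256])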
Training process:
The training process, including fine-tuning, is not included in the open-source code (there is a thread asking the authors to release the training code).
When I have time, I plan to reconstruct the model's loss and the training data organization from the paper's description.
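For reference, the paper says mask prediction is supervised with a linear combination of focal loss and dice loss (in a 20:1 ratio), and the IoU head with mean-squared error against the predicted mask's IoU. A sketch of the combined mask loss under that description (my own implementation of these standard losses, not the authors' code):

import torch
import torch.nn.functional as F

def mask_loss(pred_logits, target, focal_weight=20.0, dice_weight=1.0, alpha=0.25, gamma=2.0):
    # Focal + dice loss for predicted mask logits vs. binary ground-truth masks (sketch)
    prob = pred_logits.sigmoid()
    # Focal loss (standard formulation)
    ce = F.binary_cross_entropy_with_logits(pred_logits, target, reduction="none")
    p_t = prob * target + (1 - prob) * (1 - target)
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    focal = (alpha_t * (1 - p_t) ** gamma * ce).mean()
    # Dice loss
    inter = (prob * target).sum(dim=(-2, -1))
    union = prob.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    dice = (1.0 - (2 * inter + 1.0) / (union + 1.0)).mean()
    return focal_weight * focal + dice_weight * dice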
References:
[1] Segment Anything: https://ai.facebook.com/research/publications/segment-anything/
[2] https://github.com/facebookresearch/segment-anything