【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块ProEnco网络解析

【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块PromptEncoder网络解析

Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image encoder)计算图像嵌入,一个提示编码器(Prompt encoder)嵌入提示,然后将两个信息源组合在一个轻量级掩码解码器(Mask decoder)中来预测分割掩码。本博客将讲解Prompt encoder模块的深度学习网络代码。


前言

在详细解析SAM代码之前,首要任务是成功运行SAM代码【win10下参考教程】,后续学习才有意义。本博客讲解Prompt encoder模块的深度网络代码,不涉及其他功能模块代码。

博主将各功能模块的代码在不同的博文中进行了详细的解析,点击【win10下参考教程】,博文的目录链接放在前言部分。


PromptEncoder网络简述

SAM模型关于ProEnco网络的配置

博主以sam_vit_b为例,详细讲解ViT网络的结构。
代码位置:segment_anything/build_sam.py

def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        # 图像编码channel
        encoder_embed_dim=768,
        # 主体编码器的个数
        encoder_depth=12,
        # attention中head的个数
        encoder_num_heads=12,
        # 需要将相对位置嵌入添加到注意力图的编码器( Encoder Block)
        encoder_global_attn_indexes=[2, 5, 8, 11],
        # 权重
        checkpoint=checkpoint,
    )

sam模型中prompt_encoder模块初始化

prompt_encoder=PromptEncoder(
    # 提示编码channel(和image_encoder输出channel一致,后续会融合)
    embed_dim=prompt_embed_dim,
    # mask的编码尺寸(和image_encoder输出尺寸一致)
    image_embedding_size=(image_embedding_size, image_embedding_size),
    # 输入图像的标准尺寸
    input_image_size=(image_size, image_size),
    # 对输入掩码编码的通道数
    mask_in_chans=16,
),

ProEnco网络结构与执行流程

Prompt encoder源码位置:segment_anything/modeling/prompt_encoder.py
ProEnco网络(PromptEncoder类)结构参数配置。

def __init__(
    self,
    embed_dim: int,                         # 提示编码channel
    image_embedding_size: Tuple[int, int],  # # mask的编码尺寸
    input_image_size: Tuple[int, int],      # 输入图像的标准尺寸
    mask_in_chans: int,                     # 输入掩码编码的通道数
    activation: Type[nn.Module] = nn.GELU,  # 激活层
) -> None:
    super().__init__()
    self.embed_dim = embed_dim              # 提示编码channel
    self.input_image_size = input_image_size                # 输入图像的标准尺寸
    self.image_embedding_size = image_embedding_size        # mask的编码尺寸
    self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
    self.num_point_embeddings: int = 4                      # 4个点:正负点,框的俩个点
    point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]   # 4个点的嵌入向量
    # nn.ModuleList它是一个存储不同module,并自动将每个module的parameters添加到网络之中的容器
    self.point_embeddings = nn.ModuleList(point_embeddings)                     # 4个点的嵌入向量添加到网络
    self.not_a_point_embed = nn.Embedding(1, embed_dim)                         # 不是点的嵌入向量
    self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])           # mask的输入尺寸
    self.mask_downscaling = nn.Sequential(                                                      # 输入mask时 4倍下采样
        nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans // 4),
        activation(),
        nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans),
        activation(),
        nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
    )
    self.no_mask_embed = nn.Embedding(1, embed_dim)                         # 没有mask输入时 嵌入向量

SAM模型中ProEnco网络结构如下图所示:

ProEnco网络(PromptEncoder类)在特征提取中的几个基本步骤:

  1. Embed_Points:标记点编码(标记点由点转变为向量)
  2. Embed_Boxes:标记框编码(标记框由点转变为向量)
  3. Embed_Masks:mask编码(mask下采样保证与Image encoder输出一致)
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 获得 batchsize  当前predict为1
    bs = self._get_batch_size(points, boxes, masks)
    
    # -----sparse_embeddings----
    sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
    if points is not None:
        coords, labels = points
        point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
        sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
    if boxes is not None:
        box_embeddings = self._embed_boxes(boxes)
        sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
    # -----sparse_embeddings----
    
    # -----dense_embeddings----
    if masks is not None:
        dense_embeddings = self._embed_masks(masks)
    else:
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        )
    # -----dense_embeddings----
    
    return sparse_embeddings, dense_embeddings

获取batchsize

def _get_batch_size(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> int:
    if points is not None:
        return points[0].shape[0]
    elif boxes is not None:
        return boxes.shape[0]
    elif masks is not None:
        return masks.shape[0]
    else:
        return 1

获取设备型号

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

ProEnco网络基本步骤代码详解

Embed_Points


标记点预处理,将channel由2变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。

2:坐标(h,w)
embed_dim:提示编码的channel

Embed_Points结构如下图所示:

def _embed_points(
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,
) -> torch.Tensor:
    # 移到像素中心
    points = points + 0.5
    # points和boxes联合则不需要pad
    if pad:
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # B,1,2
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)     # B,1
        points = torch.cat([points, padding_point], dim=1)                          # B,N+1,2
        labels = torch.cat([labels, padding_label], dim=1)                          # B,N+1
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  # B,N+1,2f
    # labels为-1是非标记点,设为非标记点权重
    point_embedding[labels == -1] = 0.0
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    # labels为0是背景点,加上背景点权重
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    # labels为1的目标点,加上目标点权重
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    return point_embedding

个人理解:pad的作用相当于box占位符号,box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。

Embed_Boxes


标记框预处理,将channel由4到2再变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。

4:坐标(h1,w1,h2,w2) -->起始点与末位点
2:坐标(h,w)–>4 reshape 成 2×2
embed_dim:提示编码的channel

Embed_Boxes结构如下图所示:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    # 移到像素中心
    boxes = boxes + 0.5
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)    #
    # 目标框起始点的和末位点分别加上权重
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding

个人理解:boxes reshape 后 batchsize是会增加的,B,N,4–>BN,2,2
因此这里可以得出box和points联合标定时,box为什么只能是一个,而不能是多个。

Embed_Masks


mask的输出尺寸是Image encoder模块输出的图像编码尺寸的4倍,因此为了保持一致,需要4倍下采样。
Embed_Masks结构如下图所示:

def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
    # mask下采样4倍
    mask_embedding = self.mask_downscaling(masks)
    return mask_embedding
# 在PromptEncoder的__init__定义
self.mask_downscaling = nn.Sequential(                                                      # 输入mask时 4倍下采样
    nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans // 4),
    activation(),
    nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans),
    activation(),
    nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )

假设没有mask输入,则将no_mask_embed编码扩展到与图像编码一致的尺寸代替mask。

# 在PromptEncoder的forward定义
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
    bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)

PositionEmbeddingRandom

用于将标记点和标记框的坐标进行提示编码预处理。

def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
    super().__init__()
    if scale is None or scale <= 0.0:
        scale = 1.0
    # 理解为模型的常数 [2,f]
    self.register_buffer(
        "positional_encoding_gaussian_matrix",
        scale * torch.randn((2, num_pos_feats)),
    )

将标记点的坐标具体的位置转变为[0~1]之间的比例位置

def forward_with_coords(
    self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
    coords = coords_input.clone()
    # 将坐标位置缩放到[0~1]之间
    coords[:, :, 0] = coords[:, :, 0] / image_size[1]
    coords[:, :, 1] = coords[:, :, 1] / image_size[0]
    # B,N+1,2-->B,N+1,2f
    return self._pe_encoding(coords.to(torch.float))

标记点位置编码

因为sin和cos,编码的值归一化至 [-1,1],源码注释是[0,1],博主经过实验发现注释不对

def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
    # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
    coords = 2 * coords - 1
    # B,N+1,2 × 2,f --> B,N+1,f
    coords = coords @ self.positional_encoding_gaussian_matrix
    coords = 2 * np.pi * coords
    # outputs d_1 x ... x d_n x C shape
    # B,N+1,2f
    return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

总结

尽可能简单、详细的介绍SAM中Prompt encoder模块的ProEnco网络的代码。后续会讲解SAM的其他模块的代码。

  • 13
    点赞
  • 58
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
好的,以下是一个使用 PyTorch 实现深度学习图像分割的完整实例: 1. 数据集准备 首先,需要准备适合模型训练的数据集,可以使用一些公共数据集,如 PASCAL VOC、COCO 等。可以使用 torchvision 库中的数据加载器来加载数据集。 2. 定义模型 接下来,需要定义一个模型来执行图像分割任务。可以使用现有的预训练模型,如 UNet、FCN、SegNet 等,也可以自己设计一个模型。在 PyTorch 中,可以使用 nn.Module 类定义一个模型。 3. 定义损失函数 在深度学习中,通常使用损失函数来衡量模型的预测结果与真实结果之间的差异。对于图像分割任务,可以使用交叉熵损失函数或 Dice 损失函数等。在 PyTorch 中,可以使用 nn.CrossEntropyLoss 或 DiceLoss 来定义损失函数。 4. 定义优化器 为了训练模型,需要定义一个优化器来更新模型的参数。可以使用常见的优化器,如 SGD、Adam 等。在 PyTorch 中,可以使用 torch.optim 来定义优化器。 5. 训练模型 有了数据集、模型、损失函数和优化器,可以开始训练模型了。在 PyTorch 中,可以使用 DataLoader 来批量加载数据,使用模型的 forward 方法来进行前向传播,使用损失函数来计算损失,使用优化器来更新模型参数。 6. 测试模型 训练完成后,需要测试模型的性能。可以使用测试数据集来测试模型,并计算各种评估指标,如准确率、召回率、F1 值等。在 PyTorch 中,可以使用模型的 eval 方法来测试模型。 以上就是一个使用 PyTorch 实现深度学习图像分割的完整实例。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值