Segment Anything Model Code Walkthrough (Part 3): mask_decoder

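This article walks through the MaskDecoder class, which predicts masks with a Transformer architecture. The model takes the image embedding together with the point, box and mask prompt embeddings and runs them through a Transformer. It then uses hypernetwork MLPs to turn the mask tokens into masks, and a separate MLP head to predict each mask's quality (IoU) score. One MLP class is defined for both of these heads. The module also upsamples the image embedding before the masks are computed.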

The mask_decoder module implements a Transformer-based mask prediction model. It consists mainly of a MaskDecoder class and an MLP class:

  • class MaskDecoder(nn.Module)

    • def __init__
    • def forward
    • def predict_masks
  • class MLP(nn.Module)

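MaskDecoder is the core of the mask prediction model. It takes three inputs: the image embedding, the point/box (sparse) prompt embeddings, and the mask (dense) prompt embeddings. It adds the dense embedding to the image embedding. It then runs the Transformer over that sum together with a set of output tokens (one IoU token plus the mask tokens) and the sparse prompt tokens.

Several post-processing steps follow:
  • Two transposed convolutions upscale the image embedding 4×, from 64×64 to 256×256 in the default SAM configuration.
  • A hypernetwork turns each mask-token output into the weights that produce that mask.
  • An MLP head predicts the quality of each mask.

The class returns the predicted masks and their quality scores. It does not resize the low-resolution masks to the original image size; that happens later, in SAM's postprocessing.

(In this model the hypernetwork does not predict mask quality. It is a set of small MLPs in output_hypernetworks_mlps, one per mask token. Each MLP maps its token's output embedding to a vector of length transformer_dim // 8, which is the channel count of the upscaled embedding. The dot product of this vector with the upscaled embedding at every pixel gives that mask's logits.

Quality is predicted separately: iou_prediction_head maps the IoU token's output to num_mask_tokens = num_multimask_outputs + 1 scores, one per mask. The first mask is not a useless prediction. It is the single-mask output returned when multimask_output=False.)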
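The following is a minimal NumPy sketch of the hypernetwork step: how the per-mask weight vectors and the upscaled embedding combine into mask logits. It is a re-implementation for illustration, not the SAM code itself. The sizes are assumptions matching the default SAM setup (transformer_dim=256, a 64×64 image embedding, num_multimask_outputs=3). `conv_transpose_out` and `hypernetwork_masks` are hypothetical helper names.

```python
import numpy as np


def conv_transpose_out(size: int, kernel: int = 2, stride: int = 2, padding: int = 0) -> int:
    # Output size of a transposed convolution (dilation 1, no output_padding)
    return (size - 1) * stride - 2 * padding + kernel


def hypernetwork_masks(hyper_in: np.ndarray, upscaled: np.ndarray) -> np.ndarray:
    """Per-pixel dot product between mask weights and the upscaled embedding.

    hyper_in: (B, M, C)     one C-dim weight vector per mask token
    upscaled: (B, C, H, W)  upscaled image embedding
    returns:  (B, M, H, W)  mask logits
    """
    b, c, h, w = upscaled.shape
    # (B, M, C) @ (B, C, H*W) -> (B, M, H*W), then restore the spatial dims
    return (hyper_in @ upscaled.reshape(b, c, h * w)).reshape(b, -1, h, w)


# Assumed sizes: transformer_dim=256, 64x64 embedding, 4 mask tokens (3 + 1), 2 prompts
transformer_dim, grid, num_mask_tokens, batch = 256, 64, 4, 2
side = conv_transpose_out(conv_transpose_out(grid))  # two stride-2 layers: 64 -> 128 -> 256
c = transformer_dim // 8                              # channels after upscaling: 32

rng = np.random.default_rng(0)
hyper_in = rng.standard_normal((batch, num_mask_tokens, c))
upscaled = rng.standard_normal((batch, c, side, side))
masks = hypernetwork_masks(hyper_in, upscaled)

# The same result written as an explicit sum over the channel axis
reference = np.einsum("bmc,bchw->bmhw", hyper_in, upscaled)
print(masks.shape, np.allclose(masks, reference))  # (2, 4, 256, 256) True
```

The batched matrix product is used because it expresses "one dot product per mask and pixel" as a single operation. The `einsum` line spells out the same contraction over the channel axis.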

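MLP is a plain multilayer perceptron. MaskDecoder uses it both as the IoU prediction head and as each hypernetwork MLP. It stacks num_layers linear layers, and every layer except the last is followed by a ReLU. The sigmoid_output argument controls whether a sigmoid is applied to the final output; MaskDecoder leaves it at the default, False.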
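The next block is a minimal NumPy sketch of what MLP does, written without PyTorch so the layer sizes are easy to inspect. `mlp_layer_dims` mirrors the `zip` construction in `MLP.__init__`, and `mlp_forward` mirrors `MLP.forward`. Both names, the random weights, and the 0.05 weight scale are illustrative assumptions.

```python
import numpy as np


def mlp_layer_dims(input_dim, hidden_dim, output_dim, num_layers):
    # Same zip construction as MLP.__init__: (in, out) pairs for each nn.Linear
    h = [hidden_dim] * (num_layers - 1)
    return list(zip([input_dim] + h, h + [output_dim]))


def mlp_forward(x, weights, biases, sigmoid_output=False):
    # ReLU after every layer except the last, optional sigmoid at the end
    n = len(weights)
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w.T + b  # nn.Linear computes x @ W^T + b
        if i < n - 1:
            x = np.maximum(x, 0.0)
    if sigmoid_output:
        x = 1.0 / (1.0 + np.exp(-x))
    return x


# Hypernetwork MLP in MaskDecoder: MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
dims = mlp_layer_dims(256, 256, 32, 3)
print(dims)  # [(256, 256), (256, 256), (256, 32)]

rng = np.random.default_rng(0)
weights = [rng.standard_normal((k, n)) * 0.05 for n, k in dims]  # weight stored as (out, in)
biases = [np.zeros(k) for _, k in dims]
token = rng.standard_normal((2, 256))  # two mask-token embeddings
print(mlp_forward(token, weights, biases).shape)  # (2, 32)
```

With the defaults, the IoU head `MLP(256, 256, 4, 3)` follows the same pattern and ends in a 4-unit layer, one score per mask token.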

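The file also relies on two helper modules:
  • LayerNorm2d, imported from .common, applies Layer Normalization over the channel dimension of a 2D feature map.
  • nn.GELU is the GELU activation, used by default in the upscaling path.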
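LayerNorm2d itself lives in .common and is not shown in this file. Assume it follows the usual channel-wise formulation: each pixel is normalized over its C channels, then a per-channel weight and bias are applied. Under that assumption it can be sketched in NumPy as below. `layer_norm_2d` is a hypothetical name, and eps=1e-6 is an assumed default.

```python
import numpy as np


def layer_norm_2d(x, weight, bias, eps=1e-6):
    """Channel-wise LayerNorm for an (N, C, H, W) feature map (sketch; eps is assumed)."""
    u = x.mean(axis=1, keepdims=True)               # mean over channels at each pixel
    s = ((x - u) ** 2).mean(axis=1, keepdims=True)  # biased variance over channels
    x = (x - u) / np.sqrt(s + eps)
    # Per-channel affine transform, broadcast over N, H and W
    return weight[None, :, None, None] * x + bias[None, :, None, None]


rng = np.random.default_rng(0)
# E.g. the 64-channel output of the first ConvTranspose2d when transformer_dim=256
feat = rng.standard_normal((1, 64, 8, 8)) * 3.0 + 5.0
out = layer_norm_2d(feat, np.ones(64), np.zeros(64))
print(np.allclose(out.mean(axis=1), 0.0), np.allclose(out.var(axis=1), 1.0, atol=1e-4))  # True True
```

Unlike nn.LayerNorm applied to the last dimension, this normalizes along axis 1, so it works directly on channels-first feature maps without permuting them.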

import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
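        Predicts masks given image and prompt embeddings, using a Transformer architecture.
        Arguments:
        - transformer_dim (int): the channel dimension of the Transformer.
        - transformer (nn.Module): the Transformer used to predict masks.
        - num_multimask_outputs (int): the number of masks to predict when disambiguating masks.
        - activation (nn.Module): the type of activation to use when upscaling masks.
        - iou_head_depth (int): the depth of the MLP used to predict mask quality.
        - iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality.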
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
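        Predicts masks given image and prompt embeddings.
        Arguments:
        image_embeddings (torch.Tensor): the embeddings from the image encoder.
        image_pe (torch.Tensor): positional encoding with the same shape as image_embeddings.
        sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes.
        dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs.
        multimask_output (bool): whether to return multiple masks or a single mask.
        Returns:
        torch.Tensor: batched predicted masks.
        torch.Tensor: batched predictions of mask quality.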
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
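        # Concatenate output tokens.
        # Token order along dim 1: [IoU token, num_mask_tokens mask tokens, sparse prompt tokens]
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        # Share the same learned output tokens across every prompt in the batch
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask.
        # image_embeddings is expected to hold a single image, repeated once per prompt.
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer; it returns the updated tokens (hs) and the
        # updated image embedding, which the code expects as (B, h * w, C)
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]  # output of the IoU token
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]  # outputs of the mask tokens

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)  # back to (B, C, h, w)
        upscaled_embedding = self.output_upscaling(src)  # (B, C // 8, 4h, 4w)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            # Each mask token has its own MLP producing a C // 8 weight vector
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)  # (B, num_mask_tokens, C // 8)
        b, c, h, w = upscaled_embedding.shape
        # Per-pixel dot product gives mask logits of shape (B, num_mask_tokens, 4h, 4w)
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)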

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
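The output selection in `forward` can be checked in isolation. The sketch below reproduces the `mask_slice` logic in NumPy. `select_masks` is a hypothetical helper, and the shapes are assumptions: 2 prompts, 4 mask tokens, and 256×256 low-resolution masks.

```python
import numpy as np


def select_masks(masks, iou_pred, multimask_output):
    # Index 0 is the single-mask token; indices 1: are the multimask tokens
    mask_slice = slice(1, None) if multimask_output else slice(0, 1)
    return masks[:, mask_slice, :, :], iou_pred[:, mask_slice]


# Assumed shapes: 2 prompts, 4 mask tokens (num_multimask_outputs=3), 256x256 masks
masks = np.zeros((2, 4, 256, 256))
iou_pred = np.arange(8, dtype=float).reshape(2, 4)

m, s = select_masks(masks, iou_pred, multimask_output=True)
print(m.shape, s.tolist())  # (2, 3, 256, 256) [[1.0, 2.0, 3.0], [5.0, 6.0, 7.0]]
m, s = select_masks(masks, iou_pred, multimask_output=False)
print(m.shape, s.tolist())  # (2, 1, 256, 256) [[0.0], [4.0]]
```

With multimask_output=True, the caller receives the three disambiguation masks and their scores. With False, it receives only the output of the first mask token. A slice is used in both cases, so the mask-count dimension is kept either way.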