创新项目实训(3)——SAM模型学习 & 部署

一、python部署SAM

首先需要检查编译器python >= 3.8,并安装pytorch >= 1.7 , torchvision >= 0.8 ,我的电脑是Windows,直接执行command:

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

安装完成后,可以再安装一些基础依赖,后面处理掩码、保存格式可能用到。

pip install opencv-python pycocotools matplotlib onnxruntime onnx

安装SAM模型可以选择从GitHub上直接克隆,或者在python环境中直接pip install。

pip install git+https://github.com/facebookresearch/segment-anything.git
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .

除此之外,还要安装官方提供的预训练模型:

在这里插入图片描述

二、SAM学习

SAM的基本架构
在这里插入图片描述

可以看出SAM有三个核心部分:

  • image encoder:它负责处理原始图像,将其编码为向量,经过一系列的缩放、卷积、transformer encoder、压缩操作,最终得到图像嵌入(254 * 64 * 64);
  • prompt encoder:SAM提供了三种可选的prompt操作,包括点、框、文本,经过encoder编码后输出sparse_embeddings;此外也可以接受mask作为输入,输出dense_embeddings;
  • mask decoder:将两个encoder得到的embedding进行融合得到融合后的特征,从而预测掩码。在这之中使用了很多transformer,利用了其中的自注意力机制和交叉注意力机制

其中,关于prompt里的text形式并未在源码中公开。



SAM提供了三种模型:在这里插入图片描述这三种模型都是通过_build_sam()方法构造,在参数方面有些不同,比如图像嵌入的维度等等。

def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam

该方法完成了SAM模型的初始化,和权重的加载。从该方法中能够看到,SAM确实是由上述的三个主要模块组成。至于其中sam = Sam(...)的具体代码还要在/segment-anything/modeling/sam.py中查看。



在SAM进行predict时,首先要求输入的照片类型为RGB
在这里插入图片描述如果图片类型不是RGB / BGR类型则不予处理。(BGR图片会将其处理为RGB类型)

并通过set_torch_image方法对图片尺寸和通道顺序进行调整使其满足image_encoder的输入要求,并获得图像嵌入。在这里插入图片描述

在本项目中,我们考虑采用的是box的方法进行息肉的提取,并计划优化算法,使其不用prompt也能自动检测并分割出息肉的位置信息。

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值