一、python部署SAM
首先需要检查编译器python >= 3.8
,并安装pytorch >= 1.7 , torchvision >= 0.8
,我的电脑是Windows,直接执行command:
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
安装完成后,可以再安装一些基础依赖,后面处理掩码、保存格式可能用到。
pip install opencv-python pycocotools matplotlib onnxruntime onnx
安装SAM模型可以选择从GitHub上直接克隆,或者在python环境中直接pip install。
pip install git+https://github.com/facebookresearch/segment-anything.git
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .
除此之外,还要安装官方提供的预训练模型:
二、SAM学习
SAM的基本架构
:
可以看出SAM有三个核心部分:
- image encoder:它负责处理原始图像,将其编码为向量,经过一系列的缩放、卷积、transformer encoder、压缩操作,最终得到图像嵌入(254 * 64 * 64);
- prompt encoder:SAM提供了三种可选的prompt操作,包括点、框、文本,经过encoder编码后输出sparse_embeddings;此外也可以接受mask作为输入,输出dense_embeddings;
- mask decoder:将两个encoder得到的embedding进行融合得到融合后的特征,从而预测掩码。在这之中使用了很多transformer,利用了其中的自注意力机制和交叉注意力机制
其中,关于prompt里的text形式并未在源码中公开。
SAM提供了三种模型:这三种模型都是通过
_build_sam()
方法构造,在参数方面有些不同,比如图像嵌入的维度等等。
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Sam(
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
return sam
该方法完成了SAM模型的初始化,和权重的加载。从该方法中能够看到,SAM确实是由上述的三个主要模块组成。至于其中sam = Sam(...)
的具体代码还要在/segment-anything/modeling/sam.py
中查看。
在SAM进行predict时,首先要求输入的照片类型为RGB
。
如果图片类型不是
RGB / BGR
类型则不予处理。(BGR图片会将其处理为RGB类型)
并通过set_torch_image
方法对图片尺寸和通道顺序进行调整使其满足image_encoder的输入要求,并获得图像嵌入。
在本项目中,我们考虑采用的是box的方法进行息肉的提取,并计划优化算法,使其不用prompt也能自动检测并分割出息肉的位置信息。