【自用】SAM模型论文笔记与复现代码(segment-anything-model)

总模型结构

一个prompt encoder,对提示进行编码,image encoder对图像编码,生成embedding, 最后融合2个encoder,再接一个轻量的mask decoder,输出最后的mask。

模型结构示意图:

流程图:

模型的结构如上图所示. prompt会经过prompt encoder, 图像会经过image encoder。然后将两部分embedding经过一个轻量化的mask decoder得到融合后的特征。encoder部分使用的都是已有模型,decoder使用transformer。

image encoder

利用MAE(Masked AutoEncoder)预训练的ViT模型,对每张图片只处理一次,且在prompt encoder之前进行。输入(c,h,w)的图像,对图像进行缩放,按照长边缩放成1024,短边不够就填充,得到(c,1024,1024)的图像,经过image encoder,得到对图像16倍下采样的feature,大小为(256,64,64)。

prompt encoder

prompt encoder结构图:

分为两类:稀疏与密集

稀疏:
  • point:使用position encodings
  • box:使用position encodings
  • text:使用CLIP作为encoder
密集:
  • mask:使用卷积作为encoder

mask decoder

  • prompt self-attention
  • cross-attention(从prompt到image和从image到prompt)

valid mask(模型输出)

  • 解决混淆的输入: 对于一个prompt,模型会输出3个mask,实际上也可以输出更多的分割结果,3个可以看作一个物体的整体、部分、子部分,基本能满足大多数情况。使用IOU的方式,排序mask。在反向传播时,参与计算的只有loss最小的mask相关的参数.
  • 高效: 这里主要指的是prompt encodermask decoder。在web浏览器上,CPU计算只用约50ms

SAM模型复现

环境:

python 3.8.10
pytorch 1.11.0
cuda 11.3

环境安装

git clone https://github.com/facebookresearch/segment-anything
pip install opencv-python matplotlib
pip install -e .
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth #下载SAM_VIT-H模型

定义用于可视化的工具函数

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    
    
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

可视化原图片

image = cv2.imread('R.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(14,14))
plt.imshow(image)
plt.axis('on')
plt.show()

原图片:

加载SAM模型

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

predictor.set_image(image)

点作为prompt

单点
input_point = np.array([[430, 605]])
input_label = np.array([1])
plt.figure(figsize=(14,14))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()  

使用SAM模型进行分割,并输出模型分割出的3个mask

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True, #`multimask_output=True`表示是否输出三个mask结果
)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(14,14))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  
  

多点(使用先前单点输出的mask作为mask prompt)
仅前景点
input_point = np.array([[430, 605],[520, 650]])
input_label = np.array([1, 1]) #1代表前景点(绿色),0代表后景点(红色)

mask_input = logits[np.argmax(scores), :, :]  #选择先前分数最高的mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(14,14))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show() 

前景点+后景点
input_point = np.array([[430, 605],[520, 650], [520,500]])
input_label = np.array([1, 1, 0])  #1代表前景点(绿色),0代表后景点(红色)

mask_input = logits[np.argmax(scores), :, :]  
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(14,14))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show() 

矩形框作为prompt

单个矩形框
input_box = np.array([730, 105, 1030, 315])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

plt.figure(figsize=(17, 17))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('on')
plt.show()

多个矩形框(需要使用transform.apply_boxes_torch方法进行转换)
input_boxes = torch.tensor([
    [730, 105, 1030, 315],
    [970, 155, 1025, 250]
], device=predictor.device)

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

plt.figure(figsize=(17, 17))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('on')
plt.show()

自动分割

from segment_anything import SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('on')
plt.show() 

输出mask数量
178

开始调参

mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100, 
)
masks_2 = mask_generator_2.generate(image)
print(len(masks_2))
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks_2)
plt.axis('on')
plt.show() 

输出mask数量
335

  • 26
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
lio-sam是一个开源项目,是LIO(Linux内核iSCSI target)模块的一个分支。它是专门为高性能和可扩展性而设计的iSCSI目标代码。 lio-sam项目的主要目标是提供一个高性能的iSCSI目标,同时保持Linux kernel的稳定性和可靠性。它在传输层使用Scst(SCSI target实现)和LIO(Linux iSCSI实现)的组合,并有一些优化以提高性能。它还支持各种iSCSI功能,如CHAP认证、数据压缩和IPsec等。 代码阅读lio-sam对Linux内核和iSCSI有一定的了解是很有帮助的。lio-sam使用了一些Linux内核的机制,如工作队列和内存管理。了解这些机制将有助于理解lio-sam的实现原理和性能优化技巧。 在阅读lio-sam代码时,可以关注以下几个方面: 1. LIO模块的初始化和配置:lio-sam在加载模块时进行一些初始化工作,包括创建Scst的实例和配置iSCSI target。了解这些步骤可以帮助理解lio-sam的工作流程和配置方式。 2. iSCSI连接管理:lio-sam负责管理iSCSI连接,包括连接的建立、维护和中断。了解连接管理的实现原理可以帮助理解lio-sam如何处理多个客户端的连接和请求。 3. SCSI命令处理:lio-sam的核心功能是处理SCSI命令。了解lio-sam如何解析SCSI命令、调用底层存储设备和返回响应可以帮助理解其工作原理和性能优化方法。 4. 性能优化技巧:lio-sam的设计目标之一是提高性能。代码中可能包含一些性能优化技巧,如批量处理、IO调度和缓存管理等。了解这些技巧可以帮助优化自己的应用程序。 需要注意的是,代码阅读是一项耗时耗力的工作,需要具备一定的编程和系统知识。在阅读代码时,可以结合官方文档、论坛和社区来获取更多的信息和帮助。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值