sagment-anything官方代码使用详解

一. sagment-anything官方例程说明

1. 结果显示函数说明

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

2. SamAutomaticMaskGenerator对象

(1) SamAutomaticMaskGenerator初始化参数

  • model (Sam): 用于掩模预测的Sam模型。
  • points_per_side (int or None): 沿图像一侧要采样的点的数量。总点数为points_per_side 2 ^2 2。如果为None,则point_grids必须提供显式点采样。默认为32
  • points_per_batch (int): 设置模型同时检测的点数。更高的数字可能更快,但使用更多的GPU内存。默认为64
  • pred_iou_thresh (float): [0,1]中的滤波阈值,使用模型的预测掩码质量。默认值为0.88
  • stability_score_thresh (float): [0,1]中的滤波阈值,使用掩码在截断值变化下的稳定性,用于对模型的掩码预测进行二值化。默认值为0.95
  • stability_score_offset (float): 计算稳定性分数时,偏移截止值的量。默认值为1.0
  • box_nms_thresh (float): 非最大抑制用于过滤重复掩码的框IoU截止。默认值为0.7
  • crop_n_layers (int): 如果>0,将对图像的裁剪再次运行掩膜预测。设置要运行的层数,其中每层具有2*i_layer数量的图像裁剪。默认值为0
  • crop_nms_thresh (float): 非最大抑制用于过滤不同物体之间的重复掩码的框IoU截止。默认值为0.7
  • crop_overlap_ratio (float): 设置物体重叠的程度。在第一个裁剪层中,裁剪将重叠图像长度的这一部分。物体较多的后几层会缩小这种重叠。默认值为512 / 1500
  • crop_n_points_downscale_factor (int): 在层n中采样的每侧的点数按比例缩小crop_n_points_downscale_factor n ^n n。默认值为1
  • point_grids (list(np.ndarray) or None): 用于采样的点的显式网格上的列表,归一化为[0,1]。列表中的第n个栅格用于第n个裁剪层。与points_per_side独占。默认值为None
  • min_mask_region_area (int): 如果>0,将应用后处理来移除面积小于min_mask_region_area的掩膜来中断开连接的区域和孔。需要opencv。默认为0
  • output_mode (str): 表单掩码在中返回。可以是binary_maskuncompressed_rlecoco_rlecoco_rle需要pycocotools。对于大分辨率,binary_mask可能会消耗大量内存。默认为'binary_mask'
    “”"

3. SamPredictor对象

(1) 初始化参数

  • model (Sam): 用于掩模预测的Sam模型。

(2) set_image()

说明:
	设置检测的图像
参数:
	image(np.ndarray):用于计算掩码的图像。应为HWC uint8格式的图像,像素值为[0,255]。
	image_format(str):图像的颜色格式,以'RGB''BGR'为单位。

(3) predict()

说明:
	使用当前设置的图像预测给定输入提示的掩码。
参数:
	point_coords(np.ndarray或None):存放指向图像中物体的点的Nx2数组。每个点都以像素为单位(X,Y)。
	point_labels(np.ndarray或None):点提示的长度为N的标签阵列。1表示前景点,0表示背景点。
	box(np.ndarray或None):长度为4的数组,以XYXY格式向模型提供长方体提示。
	mask_input(np.ndarray):输入到模型的低分辨率掩码,通常来自先前的预测迭代。形式为1xHxW,其中对于SAM,H=W=256。
	multimask_output(bool):如果为true,则模型将返回三个掩码。对于不明确的输入提示(如单击),这通常会产生比单个预测更好的掩码。
	                          如果只需要单个遮罩,则可以使用模型的预测质量分数来选择最佳遮罩。对于非模糊提示,例如多个输入提示,
	                          multimask_output=False可以提供更好的结果。
	return_logits(bool):如果为true,则返回未阈值掩码logits,而不是二进制掩码。
返回值:
    (np.ndarray):CxHxW格式的输出掩码,其中C是掩码的数量,(H,W)是原始图像大小。
    (np.ndarray):长度为C的数组,包含模型对每个掩码质量的预测。
    (np.ndarray):形状为CxHxW的数组,其中C是掩码的数量,H=W=256。这些低分辨率logits可以作为掩码输入传递给后续迭代。

二. SamPredictor流程说明

1. 导入所需要的库

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

2. 读取图像

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

3. 加载模型

sam_checkpoint = "sam_vit_h_4b8939.pth"  # 模型文件所在路径
model_type = "vit_h"  # 模型的类型
device = "cuda"  # 运行模型的设备

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)  # 注册模型
sam.to(device=device)

4. 生成预测对象

mask_predictor = SamPredictor(sam)  # 生成sam预测对象

5. 设置要检测的图像

predictor.set_image(image)

6. 根据不同输入需求对图像进行掩膜预测

(1) 根据输入一个点,输出对于这个点的三个不同置信度的掩膜

input_point = np.array([[250, 187]])
input_label = np.array([1])

# 在'multimask_output=True'(默认设置)的情况下,SAM输出3个掩码,其中“scores”给出了模型对这些掩码质量的估计。
masks, scores, logits = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=True,)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

(2) 通过多个点获取一个对象的掩膜

# 通过多个点获取一个对象的掩膜
input_point = np.array([[237, 244], [273, 259]])
input_label = np.array([1, 1])  # 把两个点的标签都设置为1,代表两个点为同一个目标物所有 

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(3) 通过设置反向点反选掩膜

# 通过多个点获取一个对象的掩膜
input_point = np.array([[237, 244], [319, 274]])
input_label = np.array([1, 0])  # 把两个点的标签都设置为1,代表两个点为同一个目标物所有 

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(4) boxes输入生成掩膜

input_box = np.array([228, 230, 280, 276])

masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=input_box[None, :], multimask_output=False,)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(5) 同时输入点与boxes生成掩膜

input_point = np.array([[237, 244]])
input_label = np.array([1])
input_box = np.array([228, 230, 280, 276])

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, box=input_box[None, :], multimask_output=False,)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_points(input_point, input_label, plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(6) 多个输入输出不同预测结果

SamPredictor可以使用predict_tarch方法对同一图像输入多个提示(points、boxes)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合。例如,假设我们有几个来自对象检测器的输出结果。
SamPredictor对象(此外也可以使用segment_anything.utils.transforms)可以将boxes信息编码为特征向量(以实现对任意数量boxes的支持,transformed_boxes),然后预测mask。

input_boxes = torch.tensor([
    [228, 230, 280, 276],
    [495, 90, 554, 125],
    [447, 499, 494, 548],
    [162, 346, 214, 390],
], device=predictor.device) #假设这是目标检测的预测结果

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

masks, _, _ = predictor.predict_torch(point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False)

plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

三. SamAutomaticMaskGenerator预测流程

1. 导入所需要的库

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

2. 读取图像

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

3. 加载模型

sam_checkpoint = "sam_vit_h_4b8939.pth"  # 模型文件所在路径
model_type = "vit_h"  # 模型的类型
device = "cuda"  # 运行模型的设备

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)  # 注册模型
sam.to(device=device)

4. 生成预测对象

mask_generator = SamAutomaticMaskGenerator(model=sam,
                                           points_per_side=32,
                                           points_per_batch=64,
                                           pred_iou_thresh=0.88,
                                           stability_score_thresh=0.95,
                                           stability_score_offset=1.0,
                                           box_nms_thresh=0.7,
                                           crop_n_layers=0,
                                           crop_nms_thresh=0.7,
                                           crop_overlap_ratio=0.34133,
                                           crop_n_points_downscale_factor=1,
                                           point_grids=None,
                                           min_mask_region_area=0,
                                           output_mode='binary_mask')

5. 设置要检测的图像

# 将图像送入推理对象进行推理分割,输出结果为一个列表,其中存的每个字典对象内容为:
# segmentation : 分割出来的物体掩膜(与原图像同大小,有物体的地方为1其他地方为0)
# area : 物体掩膜的面积
# bbox : 掩膜的边界框(XYWH)
# predicted_iou : 模型自己对掩模质量的预测
# point_coords : 生成此掩码的采样输入点
# stability_score : 掩模质量的一个附加度量
# crop_box : 用于以XYWH格式生成此遮罩的图像的裁剪
masks = mask_generator.generate(image)

6. 给分割出来的物体上色,显示分割效果

# 给分割出来的物体上色,显示分割效果
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

在这里插入图片描述

四. SamAutomaticMaskGenerator不同参数下的检测效果

1. points_per_side参数测试

  1. points_per_side=4,检测到9个物体
    在这里插入图片描述

  2. points_per_side=16,检测到211个物体
    在这里插入图片描述

  3. points_per_side=64,检测到683个物体
    在这里插入图片描述

  4. points_per_side=256,检测到872个物体
    在这里插入图片描述

2. pred_iou_thresh参数测试

  1. pred_iou_thresh=1, 检测到1个物体
    在这里插入图片描述
  2. pred_iou_thresh=0.95, 检测到274个物体
    在这里插入图片描述
  3. pred_iou_thresh=0.8, 检测到792个物体
    在这里插入图片描述

3. stability_score_thresh参数测试

  1. stability_score_thresh=1,检测到0个物体
    kjui
  2. stability_score_thresh=0.95,检测到683个物体
    在这里插入图片描述
  3. stability_score_thresh=0.95,检测到764个物体
    在这里插入图片描述

4. box_nms_thresh参数测试

  1. box_nms_thresh=1,检测到4680个物体
    在这里插入图片描述

  2. box_nms_thresh=0.7,检测到683个物体
    在这里插入图片描述

  3. box_nms_thresh=0.4,检测到621个物体
    在这里插入图片描述

  4. box_nms_thresh=0.1,检测到458个物体
    在这里插入图片描述

  5. box_nms_thresh=0,检测到201个物体
    在这里插入图片描述

5. crop_nms_thresh参数测试

  1. crop_nms_thresh=1,检测到683个物体
    在这里插入图片描述

  2. crop_nms_thresh=0.7,检测到683个物体
    在这里插入图片描述

  3. crop_nms_thresh=0.1,检测到683个物体
    在这里插入图片描述

  • 25
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AoDeLuo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值