【SegmentAnything实战——推理1】交互式选取prompt points完成SAM推理自己的数据集

写在前面

网上许多推理代码都是自己输入坐标,比较繁琐,我设计了一个可以交互式选取prompt的程序代码,方面可以方便进行推理。

还可以参考本人其他的交互式选取prompt提示的文章
【SegmentAnything实战——推理2】交互式选取prompt boxes完成SAM推理自己的数据集
【SegmentAnything实战——推理3】交互式选取prompt boxes和prompt points完成SAM推理自己的数据集

参考宝藏博主的blog:

【图像分割】【深度学习】Windows10下SAM(Segment Anything)官方代码Pytorch实现与源码讲解
本人最近也在学习SAM模型,记录相关知识点,分享学习中遇到的问题已经解决的方法


图片来源:

https://www.photophoto.cn/sucai/20184408.html


效果+代码

1. 单张图片上多个prompt points进行预测

  鼠标左键点击为前景点(红色), 右键点击为背景点(蓝色),关闭图片开始推理

1.1 选点

在这里插入图片描述

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
'''
步骤1: 查看测试图片显示前景和背景的标记点
      鼠标左键为选前景点,鼠标右键为选背景点,关闭图像退出选点
'''
# 定义回调函数,用于处理鼠标点击事件
def onclick_points(event):
    if event.button == 1:  # 鼠标左键
        label = 1
    elif event.button == 3:  # 鼠标右键
        label = 0
    else:
        return
    x = event.xdata
    y = event.ydata
    # 将点击的点的坐标和类型保存到列表中
    points_coords.append([x, y])
    labels_coords.append(label)
    # 在图像上绘制点击的点
    plt.title("choose pos_points and neg_points")
    plt.scatter(x, y, marker='.', s=200, c='red' if int(label) == 1 else 'blue')
    plt.draw()
    # 打印保存的点坐标和类型
    print(f"坐标:({x}, {y}),类型:{label}")



def show_pic_choose_points(img):
    global points_coords
    global labels_coords
    # 创建一个新的图像窗口并显示图像
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title("choose pos_points and neg_points")
    # 定义存储点坐标和类型的列表
    points_coords = []
    labels_coords = []
    # 绑定鼠标点击事件处理函数
    cid = fig.canvas.mpl_connect('button_press_event', onclick_points)
    # 显示图像和等待用户点击点
    plt.show()
    return points_coords, labels_coords
1.2 multimask_output = False时,只输出一个mask

在这里插入图片描述


1.3 multimask_output = True时,输出三个mask

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

# =============选点
B = 1   # ♥
# ============预测
# 单张points图片只需要predict 多张points则需要predict_torch
if B == 1:
    # 选点
    points, labels = show_pic_choose_points(image)
    input_points, input_labels = np.array(points.copy()), np.array(labels.copy())  # input_points(n,2) input_labels(n,)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,     # 形式 [[x,y], [x,y],...] (n,2)
        point_labels=input_labels,      # 形式 [0, 1, ...] (n, )
        multimask_output=multimask_output_bool,  # ♥
    )
    # =============显示预测
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_points, input_labels, plt.gca())
        plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

2. 多张图片上多个prompt points进行预测

  要注意,每次选择点时,前景点和背景点的总点数必须每张图片是一样的
2.1 选点(以选两张点为例)

在这里插入图片描述在这里插入图片描述

2.2 预测
else:
    Batch_points = []
    Batch_labels = []
    for i in range(B):
        points, labels = show_pic_choose_points(image)
        input_points, input_labels = np.array(points.copy()), np.array(labels.copy())
        Batch_points.append(input_points)
        Batch_labels.append(input_labels)
    Batch_points, Batch_labels = np.array(Batch_points), np.array(Batch_labels)  # (B,n,2) (B,n)
    points_torch = torch.as_tensor(Batch_points, dtype=torch.float, device=sam.device)
    labels_torch = torch.as_tensor(Batch_labels, dtype=torch.float, device=sam.device)

    masks, scores, _ = predictor.predict_torch(  # 这个函数也不一样  # masks(BxCxHxW)
        point_coords=points_torch,
        point_labels=labels_torch,
        boxes=None,    # 形式 [[x1,y1,x2,y2], [x1,y1,x2,y2],...] (n,4)
        multimask_output=multimask_output_bool,  # ♥
    )
    # =============显示预测
    for i, (mask_i, scores_i, points_i, labels_i) in enumerate(zip(masks, scores, Batch_points, Batch_labels)):
        # mask_i (C,H,W) scores_i(C) points_i (n,2)
        for j in range(mask_i.shape[0]):
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            show_mask(mask_i[0].cpu().numpy(), plt.gca())
            show_points(points_i, labels_i, plt.gca())
            plt.title(f"Points({i}):Mask {j + 1}, Score: {scores_i[j].cpu().numpy():.3f}", fontsize=18)
            plt.axis('off')
            plt.show()

显示图片有点大,我只显示重要部分
在这里插入图片描述在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

可运行的全部代码

将.py代码放入segment-anything-main的根目录下就行,然后创建一个testset文件夹放入数据集图片,创建checkpoints文件夹存放SAM模型权重,我是三个权重都下载了,都运行了一下,速度还是比较快的。


全部代码需要自定义的地方:

multimask_output_bool = True # 是否需要多个输出

B = 2 # 想输入多少batch个prompt点的提示

path = “./testset/R.jpg” # 图片路径,最好整理起来放在testset文件夹下面


注意:

batch=1的输入和batch>1的输入的区别就是输入形式不同,因此会使用到predictor.predict 和 predictor.predict_torch

先看看predictor.py中的函数参数定义

    def predict(  # 多个点可以使用,单个box可以使用
        self,
        point_coords: Optional[np.ndarray] = None,  # np.array([[10, 20], [30, 40], [40, 50]]) (3, 2)
        point_labels: Optional[np.ndarray] = None,  # np.array([1, 0, 1])  (3,)
        box: Optional[np.ndarray] = None,  # box = np.array([10, 20, 50, 60])  (4,)
        mask_input: Optional[np.ndarray] = None,  # 1xHxW
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        使用当前设置的图像(1张),对给定的输入提示进行掩码预测。
        参数:
          point_coords (np.ndarray or None): N个点提示的数组,每个点以像素(X,Y)表示 [[102,204],[100,200]...]
          point_labels (np.ndarray or None): 长度为N的点提示的标签数组。1表示前景点,0表示背景点。[1, 0, ...]
          box (np.ndarray or None): 给定的框提示的长度为4的数组,以XYXY格式表示。[[x, y, x+w, y+h]]只支持一个框
          mask_input (np.ndarray): 低分辨率的掩码输入到模型中,通常来自先前的预测迭代。其形状为1xHxW,其中对于SAM,H=W=256。
          multimask_output (bool): 如果为True,则模型将返回三个掩码。对于模糊的输入提示(例如单击),这通常会产生比单个预测更好的掩码。
                                  如果只需要一个单独的掩码,则可以使用模型预测的质量分数来选择最佳的掩码。
                                  对于非模糊的提示,例如多个输入提示,multimask_output=False可以给出更好的结果
          return_logits (bool): 如果为True,则返回未阈值化的掩码logits,而不是二进制掩码。
                                Logits是指模型输出的未经过激活函数(如sigmoid或softmax)处理的原始数值。

        Returns:
          (np.ndarray): 以CxHxW格式表示的输出掩码,其中C是掩码的数量,(H,W)是原始图像的尺寸。
          (np.ndarray): 长度为C的数组,包含每个掩码的质量预测。
          (np.ndarray): 形状为CxHxW的数组,其中C是掩码的数量,H=W=256。这些低分辨率的logits可以作为下一次迭代的掩码输入
        """

 @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        对于给定的输入提示,使用当前设置的图像预测掩码。
        输入提示是♥批量处理♥的torch张量,预期已经使用ResizeLongestSide方法转换为输入帧的尺寸。

        Arguments:
          point_coords (torch.Tensor or None): 大小为BxNx2的点提示数组,每个点以像素(X,Y)表示。 # torch.Size([1, n, 2])
          point_labels (torch.Tensor or None): 大小为BxN的点提示标签数组。1表示前景点,0表示背景点。  # torch.Size([1, n])
          boxes (torch.Tensor or None): 大小为Bx4的框提示数组,以XYXY格式表示。  # torch.Size([1, 1, 4])  
          mask_input (torch.Tensor):  # torch.Size([Bx1xHxW])
          multimask_output (bool): 如果为True,则模型将返回三个掩码。对于模糊的输入提示(例如单击),这通常会产生比单个预测更好的掩码。
                                  如果只需要一个单独的掩码,则可以使用模型预测的质量分数来选择最佳的掩码。
                                  对于非模糊的提示,例如多个输入提示,multimask_output=False可以给出更好的结果。

        Returns:
          (torch.Tensor): 以BxCxHxW格式表示的输出掩码,其中B是批次大小,C是掩码的数量,(H,W)是原始图像的尺寸。
          (torch.Tensor): 形状为BxC的数组,包含每个掩码的质量预测。
          (torch.Tensor): 形状为BxCxHxW的数组,其中B是批次大小,C是掩码的数量,H=W=256。这些低分辨率的logits可以作为下一次迭代的掩码输入。
        """

全部代码

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
'''
步骤1: 查看测试图片显示前景和背景的标记点
      鼠标左键为选前景点,鼠标右键为选背景点,关闭图像退出选点
'''
# 定义回调函数,用于处理鼠标点击事件
def onclick_points(event):
    if event.button == 1:  # 鼠标左键
        label = 1
    elif event.button == 3:  # 鼠标右键
        label = 0
    else:
        return
    x = event.xdata
    y = event.ydata
    # 将点击的点的坐标和类型保存到列表中
    points_coords.append([x, y])
    labels_coords.append(label)
    # 在图像上绘制点击的点
    plt.title("choose pos_points and neg_points")
    plt.scatter(x, y, marker='.', s=200, c='red' if int(label) == 1 else 'blue')
    plt.draw()
    # 打印保存的点坐标和类型
    print(f"坐标:({x}, {y}),类型:{label}")



def show_pic_choose_points(img):
    global points_coords
    global labels_coords
    # 创建一个新的图像窗口并显示图像
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title("choose pos_points and neg_points")
    # 定义存储点坐标和类型的列表
    points_coords = []
    labels_coords = []
    # 绑定鼠标点击事件处理函数
    cid = fig.canvas.mpl_connect('button_press_event', onclick_points)
    # 显示图像和等待用户点击点
    plt.show()
    return points_coords, labels_coords


def show_points(coords, labels, ax):
    # 筛选出前景目标标记点
    pos_points = coords[labels == 1]
    # 筛选出背景目标标记点
    neg_points = coords[labels == 0]
    # x-->pos_points[:, 0] y-->pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='red', marker='.', s=200)  # 前景的标记点显示
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='blue', marker='.', s=200)  # 背景的标记点显示


def show_mask(mask, ax, random_color=False):
    if random_color:    # 掩膜颜色是否随机决定
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


path = "./testset/R.jpg"
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

multimask_output_bool = True # ♥


# ============加载模型
sam_checkpoints_b = "./checkpoints/sam_vit_b_01ec64.pth"
sam_checkpoints_l = "./checkpoints/sam_vit_l_0b3195.pth"
sam_checkpoints_h = "./checkpoints/sam_vit_h_4b8939.pth"
# 模型类型
model_type_b = "vit_b"
model_type_l = "vit_l"
model_type_h = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type_h](sam_checkpoints_h)  # ♥
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)

# =============选点
B = 2   # ♥
# ============预测
# 单张points图片只需要predict 多张points则需要predict_torch
if B == 1:
    # 选点
    points, labels = show_pic_choose_points(image)
    input_points, input_labels = np.array(points.copy()), np.array(labels.copy())  # input_points(n,2) input_labels(n,)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,     # 形式 [[x,y], [x,y],...] (n,2)
        point_labels=input_labels,      # 形式 [0, 1, ...] (n, )
        multimask_output=multimask_output_bool,  # ♥
    )
    # =============显示预测
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_points, input_labels, plt.gca())
        plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()
else:
    Batch_points = []
    Batch_labels = []
    for i in range(B):
        points, labels = show_pic_choose_points(image)
        input_points, input_labels = np.array(points.copy()), np.array(labels.copy())
        Batch_points.append(input_points)
        Batch_labels.append(input_labels)
    Batch_points, Batch_labels = np.array(Batch_points), np.array(Batch_labels)  # (B,n,2) (B,n)
    points_torch = torch.as_tensor(Batch_points, dtype=torch.float, device=sam.device)
    labels_torch = torch.as_tensor(Batch_labels, dtype=torch.float, device=sam.device)

    masks, scores, _ = predictor.predict_torch(  # 这个函数也不一样  # masks(BxCxHxW)
        point_coords=points_torch,
        point_labels=labels_torch,
        boxes=None,    # 形式 [[x1,y1,x2,y2], [x1,y1,x2,y2],...] (n,4)
        multimask_output=multimask_output_bool,  # ♥
    )
    # =============显示预测
    for i, (mask_i, scores_i, points_i, labels_i) in enumerate(zip(masks, scores, Batch_points, Batch_labels)):
        # mask_i (C,H,W) scores_i(C) points_i (n,2)
        for j in range(mask_i.shape[0]):
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            show_mask(mask_i[0].cpu().numpy(), plt.gca())
            show_points(points_i, labels_i, plt.gca())
            plt.title(f"Points({i}):Mask {j + 1}, Score: {scores_i[j].cpu().numpy():.3f}", fontsize=18)
            plt.axis('off')
            plt.show()

总结

总体来说,如果想全面了解sam模型代码,最重要的还是把SAM的predictor.py文件好好看看,捋清楚维度的变化,什么时候是numpy,什么时候是tensor,输入形式是什么?输出形式又是什么?这里面大有学问呀!

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
在Windows系统中,您可以按照以下步骤使用YOLOv8训练自己的数据集: 1. 确保您的系统已经安装了Python和PyTorch,并且已经正确配置了CUDA。 2. 下载YOLOv8代码库,并将其解压到您的工作目录中。 3. 根据您的数据集,将训练图像和标注文件存储在适当的文件夹中。标注文件可以是txt文件,其中每行表示一个对象的标注信息,包括类别和边界框的坐标。确保标签格式与YOLOv8要求的格式一致。 4. 打开配置文件default.yaml,根据您的数据集和训练需求,进行必要的配置更改。您可以设置训练集和验证集的路径、类别数量、批处理大小等参数。 5. 打开命令提示符或Anaconda Prompt,进入YOLOv8代码库所在的目录。 6. 使用以下命令开始训练模型: ``` python train.py --data data.yaml --cfg models/yolov8.yaml --weights '' --batch-size 16 ``` 其中,--data用于指定数据集的配置文件,--cfg用于指定模型的配置文件,--weights用于指定预训练的权重文件,可以选择从头开始训练,--batch-size用于指定批处理大小。 7. 等待训练完成,训练过程中会在模型文件夹中保存权重文件。 请注意,以上步骤仅为大致描述,具体细节可能因YOLOv8的版本和您的数据集而有所不同。建议您参考YOLOv8的官方文档和示例代码,以获得更详细的指导和说明。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [【YOLO】YOLOv8训练自定义数据集(4种方式)](https://blog.csdn.net/weixin_42166222/article/details/129391260)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* [windows10 yolov3训练自己的数据.docx](https://download.csdn.net/download/qq_36614037/12682966)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值