【SegmentAnything实战——推理3】交互式选取prompt boxes和prompt points完成SAM推理自己的数据集

写在前面

网上许多推理代码都是自己输入坐标,比较繁琐,我设计了一个可以交互式选取prompt的程序代码,方面可以方便进行推理。

还可以参考本人其他的交互式选取prompt提示的文章

  1. 【SegmentAnything实战——推理1】交互式选取prompt points完成SAM推理自己的数据集
  2. 【SegmentAnything实战——推理2】交互式选取prompt boxes完成SAM推理自己的数据集

参考宝藏博主的blog:

【图像分割】【深度学习】Windows10下SAM(Segment Anything)官方代码Pytorch实现与源码讲解
本人最近也在学习SAM模型,记录相关知识点,分享学习中遇到的问题已经解决的方法


图片来源:

https://www.photophoto.cn/sucai/20184408.html


效果+代码

1. 选目标框(1个)+points进行预测

  鼠标左键点击一下矩形框起点, 再点击一下矩形框终点,自动形成矩形框,关闭图片开始选点,鼠标左键选择前景点,鼠标右键选择背景点。

1.1 效果

如果multiple_box_bool = False 那么意味着只能选一个目标框,box和points联合预测的话,选择的目标框只能为一个,因此只能为False,目标框选定后,想要再选择box,是没有反应的,可以看看后面的代码

![在这里插入图片描述](https://img-blog.csdnimg.cn/e926bd58d932474c8f69bf449a8d6ebc.png
在这里插入图片描述


选定的目标框坐标:

在这里插入图片描述
选定的点坐标: 类型0为背景点,1为前景点

在这里插入图片描述


单个输出:

在这里插入图片描述




可运行的全部代码

将.py代码放入segment-anything-main的根目录下就行,然后创建一个testset文件夹放入数据集图片,创建checkpoints文件夹存放SAM模型权重,我是三个权重都下载了,都运行了一下,速度还是比较快的。

全部代码需要自定义的地方:

multimask_output_bool = True # 是否需要多个输出

multiple_box_bool = True # 是否要选择多个目标框boxes

path = “./testset/R.jpg” # 图片路径,最好整理起来放在testset文件夹下面


完整代码

"""
box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。
"""

import cv2
import matplotlib.pyplot as plt
import numpy as np
from segment_anything import SamPredictor, sam_model_registry


'''
步骤1: 查看测试图片显示前景和背景的标记点
      鼠标左键为选前景点,鼠标右键为选背景点,关闭图像退出选点
'''


# 定义回调函数,用于处理鼠标点击事件
def onclick_points(event):
    if event.button == 1:  # 鼠标左键
        label = 1
    elif event.button == 3:  # 鼠标右键
        label = 0
    else:
        return
    x = event.xdata
    y = event.ydata
    # 将点击的点的坐标和类型保存到列表中
    points_coords.append([x, y])
    labels_coords.append(label)
    # 在图像上绘制点击的点
    plt.title("choose pos_points and neg_points")
    plt.scatter(x, y, marker='.', s=200, c='red' if int(label) == 1 else 'blue')
    plt.draw()
    # 打印保存的点坐标和类型
    print(f"坐标:({x}, {y}),类型:{label}")



def show_pic_choose_points(img):
    global points_coords
    global labels_coords
    # 创建一个新的图像窗口并显示图像
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title("choose pos_points and neg_points")
    # 定义存储点坐标和类型的列表
    points_coords = []
    labels_coords = []
    # 绑定鼠标点击事件处理函数
    cid = fig.canvas.mpl_connect('button_press_event', onclick_points)
    # 显示图像和等待用户点击点
    plt.show()
    return points_coords, labels_coords


def show_points(coords, labels, ax):
    # 筛选出前景目标标记点
    pos_points = coords[labels == 1]
    # 筛选出背景目标标记点
    neg_points = coords[labels == 0]
    # x-->pos_points[:, 0] y-->pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='red', marker='.', s=200)  # 前景的标记点显示
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='blue', marker='.', s=200)  # 背景的标记点显示


def show_mask(mask, ax, random_color=False):
    if random_color:    # 掩膜颜色是否随机决定
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


# 定义回调函数,用于处理鼠标点击事件
def onclick_box(event):
    if event.button == 1:  # 鼠标左键
        if len(points_box) == 1:
            if flag == False:
                return
        clicked_points.append([event.xdata, event.ydata])
        # 在图像上绘制点击的点
        plt.title("choose a box!")
        plt.scatter(event.xdata, event.ydata, marker='.', s=200, c='red')
        plt.draw()

        if len(clicked_points) == 2:  #当点击了两个点
            x1, y1 = clicked_points[0]
            x2, y2 = clicked_points[1]
            points_box.append([x1, y1, x2, y2])

            # 在图像上绘制矩形
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
            plt.gca().add_patch(rect)
            plt.draw()
            # 打印保存的矩形坐标和类型
            print(f"矩形坐标:({x1}, {y1}, {x2}, {y2})")
            clicked_points.clear()  # 清空已点击的点列表

    else:
        return



def show_pic_choose_boxes(img, multiple_box_bool):
    global points_box
    global clicked_points
    global flag
    # 创建一个新的图像窗口并显示图像
    fig, ax = plt.subplots()
    ax.imshow(img)
    # 设置标题
    ax.set_title("choose a box!")
    # 定义存储点坐标和类型的列表
    points_box = []
    clicked_points = []
    flag = multiple_box_bool
    boxes = []
    # 绑定鼠标点击事件处理函数
    cid = fig.canvas.mpl_connect('button_press_event', onclick_box)
    # 显示图像和等待用户点击点
    plt.show()
    boxes = points_box.copy()
    return boxes


def show_box(boxes, ax):
    # 画出标定框 x0 y0是起始坐标
    for box in boxes:
        x0, y0 = box[0], box[1]
        # w h 是框的尺寸
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))



path = "./testset/R.jpg"
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# 多个输出?
multimask_output_bool = False # ♥

# =============选框
multiple_box_bool = False  # 只能False
boxes = show_pic_choose_boxes(image, multiple_box_bool)
input_boxes = np.array(boxes.copy())
print("=======选框")
print(input_boxes)
# =============选点
points_coords, labels_coords = show_pic_choose_points(image)
input_points, input_labels = np.array(points_coords.copy()), np.array(labels_coords.copy())  # input_points(n,2) input_labels(n,)
print("=======选点")
print(input_points, input_labels)
# ============加载模型
sam_checkpoints_b = "./checkpoints/sam_vit_b_01ec64.pth"
sam_checkpoints_l = "./checkpoints/sam_vit_l_0b3195.pth"
sam_checkpoints_h = "./checkpoints/sam_vit_h_4b8939.pth"
# 模型类型
model_type_b = "vit_b"
model_type_l = "vit_l"
model_type_h = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type_h](sam_checkpoints_h)  # ♥
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# ============预测
masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    box=input_boxes,
    multimask_output=multimask_output_bool,
)
# =============显示预测
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_points, input_labels, plt.gca())
    show_box(input_boxes, plt.gca())
    plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()


总结

总体来说,如果想全面了解sam模型代码,最重要的还是把SAM的predictor.py文件好好看看,捋清楚维度的变化,什么时候是numpy,什么时候是tensor,输入形式是什么?输出形式又是什么?这里面大有学问呀again~!

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值