【SegmentAnything实战——推理2】交互式选取prompt boxes完成SAM推理自己的数据集

写在前面

网上许多推理代码都是自己输入坐标,比较繁琐,我设计了一个可以交互式选取prompt的程序代码,方面可以方便进行推理。

还可以参考本人其他的交互式选取prompt提示的文章

  1. 【SegmentAnything实战——推理1】交互式选取prompt points完成SAM推理自己的数据集
  2. 【SegmentAnything实战——推理3】交互式选取prompt boxes和prompt points完成SAM推理自己的数据集

参考宝藏博主的blog:

【图像分割】【深度学习】Windows10下SAM(Segment Anything)官方代码Pytorch实现与源码讲解
本人最近也在学习SAM模型,记录相关知识点,分享学习中遇到的问题已经解决的方法


图片来源:

https://www.photophoto.cn/sucai/20184408.html


效果+代码

1. 单个prompt box进行预测

  鼠标左键点击一下矩形框起点, 再点击一下矩形框终点,自动形成矩形框,关闭图片开始推理

1.1 效果

如果multiple_box_bool = False 那么意味着只能选一个目标框,目标框选定后,想要再选择box,是没有反应的,除非multiple_box_bool = True, 可以看看后面的代码

在这里插入图片描述


选定的目标框坐标:

在这里插入图片描述


单个输出:

在这里插入图片描述


多个输出:

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述


2. 多个prompt boxes进行预测

  这个多个prompt boxes, 可以是一张图片上多个目标框,
  或者Batch个输入,但是batch中每个子部分只能有一个目标框
  
  鼠标左键点击一下矩形框起点, 再点击一下矩形框终点,自动形成矩形框,关闭图片开始推理

2.1 效果

如果multiple_box_bool = False 那么意味着只能选一个目标框,目标框选定后,想要再选择box,是没有反应的,除非multiple_box_bool = True, 可以看看后面的代码
在这里插入图片描述
在这里插入图片描述


输出:

在这里插入图片描述


可运行的全部代码

将.py代码放入segment-anything-main的根目录下就行,然后创建一个testset文件夹放入数据集图片,创建checkpoints文件夹存放SAM模型权重,我是三个权重都下载了,都运行了一下,速度还是比较快的。

全部代码需要自定义的地方:

multimask_output_bool = True # 是否需要多个输出

multiple_box_bool = True # 是否要选择多个目标框boxes

path = “./testset/R.jpg” # 图片路径,最好整理起来放在testset文件夹下面

注意:

multiple_box_bool = True /False 的区别就是输入形式不同,因此会使用到predictor.predict 和 predictor.predict_torch

完整代码

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry

# 定义回调函数,用于处理鼠标点击事件
def onclick_box(event):
    if event.button == 1:  # 鼠标左键
        if len(points_box) == 1:
            if flag == False:
                return
        clicked_points.append([event.xdata, event.ydata])
        # 在图像上绘制点击的点
        if flag:
            plt.title("choose some boxes!")
        else:
            plt.title("choose a box!")
        plt.scatter(event.xdata, event.ydata, marker='.', s=200, c='red')
        plt.draw()

        if len(clicked_points) == 2:  #当点击了两个点
            x1, y1 = clicked_points[0]
            x2, y2 = clicked_points[1]
            points_box.append([x1, y1, x2, y2])

            # 在图像上绘制矩形
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
            plt.gca().add_patch(rect)
            plt.draw()
            # 打印保存的矩形坐标和类型
            print(f"矩形坐标:({x1}, {y1}, {x2}, {y2})")
            clicked_points.clear()  # 清空已点击的点列表

    else:
        return


def show_pic_choose_boxes(img, multiple_box_bool):
    global points_box
    global clicked_points
    global flag
    # 创建一个新的图像窗口并显示图像
    fig, ax = plt.subplots()
    ax.imshow(img)
    flag = multiple_box_bool
    # 设置标题
    if flag:
        ax.set_title("choose some boxes!")
    else:
        ax.set_title("choose a box!")
    # 定义存储点坐标和类型的列表
    points_box = []
    clicked_points = []
    boxes = []
    # 绑定鼠标点击事件处理函数
    cid = fig.canvas.mpl_connect('button_press_event', onclick_box)
    # 显示图像和等待用户点击点
    plt.show()
    boxes = points_box.copy()
    return boxes


def show_box(boxes, ax):
    # 画出标定框 x0 y0是起始坐标
    for box in boxes:
        x0, y0 = box[0], box[1]
        # w h 是框的尺寸
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_mask(mask, ax, random_color=False):
    if random_color:    # 掩膜颜色是否随机决定
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


path = "./testset/R.jpg"
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

multimask_output_bool = True # ♥

# =============选框
multiple_box_bool = True  # ♥
boxes = show_pic_choose_boxes(image, multiple_box_bool)
input_boxes = np.array(boxes.copy())
print(input_boxes)
# ============加载模型
sam_checkpoints_b = "./checkpoints/sam_vit_b_01ec64.pth"
sam_checkpoints_l = "./checkpoints/sam_vit_l_0b3195.pth"
sam_checkpoints_h = "./checkpoints/sam_vit_h_4b8939.pth"
# 模型类型
model_type_b = "vit_b"
model_type_l = "vit_l"
model_type_h = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type_h](sam_checkpoints_h)  # ♥
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)

# ==========如果是多个目标狂,则需要对box_points进行转换
# ♥ 这个多个目标框的意思其实是 一张图片上多个目标框,或者多张图片上只有一个目标框
if multiple_box_bool:
    boxes_torch = torch.as_tensor(input_boxes, dtype=torch.float, device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_torch, image.shape[:2])  # (n,4) [[ 20.  40. 100. 120.]]

    # ============多个目标框预测
    masks, _, _ = predictor.predict_torch(  # 这个函数也不一样
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,    # 形式 [[x1,y1,x2,y2], [x1,y1,x2,y2],...] (n,4)
        multimask_output=False,  # ♥
    )
    # ==========多个目标狂显示预测
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        show_box(boxes_torch.cpu().numpy(), plt.gca())
    plt.axis('off')
    plt.show()

else:
    # ============单个预测
    masks, scores, logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,  # 输入格式为(1,4) 只能是一个box的坐标
        multimask_output=multimask_output_bool,  # ♥
    )
    # =============显示预测
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_box(input_boxes, plt.gca())
        plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()



# 总结

###   总体来说,如果想全面了解sam模型代码,最重要的还是把SAM的predictor.py文件好好看看,捋清楚维度的变化,什么时候是numpy,什么时候是tensor,输入形式是什么?输出形式又是什么?这里面大有学问呀!
  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值