prompt boxes+points联合推理
写在前面
网上许多推理代码都是自己输入坐标,比较繁琐,我设计了一个可以交互式选取prompt的程序代码,方面可以方便进行推理。
还可以参考本人其他的交互式选取prompt提示的文章
参考宝藏博主的blog:
【图像分割】【深度学习】Windows10下SAM(Segment Anything)官方代码Pytorch实现与源码讲解
本人最近也在学习SAM模型,记录相关知识点,分享学习中遇到的问题已经解决的方法
图片来源:
https://www.photophoto.cn/sucai/20184408.html
效果+代码
1. 选目标框(1个)+points进行预测
鼠标左键点击一下矩形框起点, 再点击一下矩形框终点,自动形成矩形框,关闭图片开始选点,鼠标左键选择前景点,鼠标右键选择背景点。
1.1 效果
如果multiple_box_bool = False 那么意味着只能选一个目标框,box和points联合预测的话,选择的目标框只能为一个,因此只能为False,目标框选定后,想要再选择box,是没有反应的,可以看看后面的代码
选定的目标框坐标:
选定的点坐标: 类型0为背景点,1为前景点
单个输出:
可运行的全部代码
将.py代码放入segment-anything-main的根目录下就行,然后创建一个testset文件夹放入数据集图片,创建checkpoints文件夹存放SAM模型权重,我是三个权重都下载了,都运行了一下,速度还是比较快的。
全部代码需要自定义的地方:
multimask_output_bool = True # 是否需要多个输出
multiple_box_bool = True # 是否要选择多个目标框boxes
path = “./testset/R.jpg” # 图片路径,最好整理起来放在testset文件夹下面
完整代码
"""
box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
from segment_anything import SamPredictor, sam_model_registry
'''
步骤1: 查看测试图片显示前景和背景的标记点
鼠标左键为选前景点,鼠标右键为选背景点,关闭图像退出选点
'''
# 定义回调函数,用于处理鼠标点击事件
def onclick_points(event):
if event.button == 1: # 鼠标左键
label = 1
elif event.button == 3: # 鼠标右键
label = 0
else:
return
x = event.xdata
y = event.ydata
# 将点击的点的坐标和类型保存到列表中
points_coords.append([x, y])
labels_coords.append(label)
# 在图像上绘制点击的点
plt.title("choose pos_points and neg_points")
plt.scatter(x, y, marker='.', s=200, c='red' if int(label) == 1 else 'blue')
plt.draw()
# 打印保存的点坐标和类型
print(f"坐标:({x}, {y}),类型:{label}")
def show_pic_choose_points(img):
global points_coords
global labels_coords
# 创建一个新的图像窗口并显示图像
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("choose pos_points and neg_points")
# 定义存储点坐标和类型的列表
points_coords = []
labels_coords = []
# 绑定鼠标点击事件处理函数
cid = fig.canvas.mpl_connect('button_press_event', onclick_points)
# 显示图像和等待用户点击点
plt.show()
return points_coords, labels_coords
def show_points(coords, labels, ax):
# 筛选出前景目标标记点
pos_points = coords[labels == 1]
# 筛选出背景目标标记点
neg_points = coords[labels == 0]
# x-->pos_points[:, 0] y-->pos_points[:, 1]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='red', marker='.', s=200) # 前景的标记点显示
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='blue', marker='.', s=200) # 背景的标记点显示
def show_mask(mask, ax, random_color=False):
if random_color: # 掩膜颜色是否随机决定
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
# 定义回调函数,用于处理鼠标点击事件
def onclick_box(event):
if event.button == 1: # 鼠标左键
if len(points_box) == 1:
if flag == False:
return
clicked_points.append([event.xdata, event.ydata])
# 在图像上绘制点击的点
plt.title("choose a box!")
plt.scatter(event.xdata, event.ydata, marker='.', s=200, c='red')
plt.draw()
if len(clicked_points) == 2: #当点击了两个点
x1, y1 = clicked_points[0]
x2, y2 = clicked_points[1]
points_box.append([x1, y1, x2, y2])
# 在图像上绘制矩形
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
plt.gca().add_patch(rect)
plt.draw()
# 打印保存的矩形坐标和类型
print(f"矩形坐标:({x1}, {y1}, {x2}, {y2})")
clicked_points.clear() # 清空已点击的点列表
else:
return
def show_pic_choose_boxes(img, multiple_box_bool):
global points_box
global clicked_points
global flag
# 创建一个新的图像窗口并显示图像
fig, ax = plt.subplots()
ax.imshow(img)
# 设置标题
ax.set_title("choose a box!")
# 定义存储点坐标和类型的列表
points_box = []
clicked_points = []
flag = multiple_box_bool
boxes = []
# 绑定鼠标点击事件处理函数
cid = fig.canvas.mpl_connect('button_press_event', onclick_box)
# 显示图像和等待用户点击点
plt.show()
boxes = points_box.copy()
return boxes
def show_box(boxes, ax):
# 画出标定框 x0 y0是起始坐标
for box in boxes:
x0, y0 = box[0], box[1]
# w h 是框的尺寸
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
path = "./testset/R.jpg"
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 多个输出?
multimask_output_bool = False # ♥
# =============选框
multiple_box_bool = False # 只能False
boxes = show_pic_choose_boxes(image, multiple_box_bool)
input_boxes = np.array(boxes.copy())
print("=======选框")
print(input_boxes)
# =============选点
points_coords, labels_coords = show_pic_choose_points(image)
input_points, input_labels = np.array(points_coords.copy()), np.array(labels_coords.copy()) # input_points(n,2) input_labels(n,)
print("=======选点")
print(input_points, input_labels)
# ============加载模型
sam_checkpoints_b = "./checkpoints/sam_vit_b_01ec64.pth"
sam_checkpoints_l = "./checkpoints/sam_vit_l_0b3195.pth"
sam_checkpoints_h = "./checkpoints/sam_vit_h_4b8939.pth"
# 模型类型
model_type_b = "vit_b"
model_type_l = "vit_l"
model_type_h = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type_h](sam_checkpoints_h) # ♥
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# ============预测
masks, scores, logits = predictor.predict(
point_coords=input_points,
point_labels=input_labels,
box=input_boxes,
multimask_output=multimask_output_bool,
)
# =============显示预测
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(input_points, input_labels, plt.gca())
show_box(input_boxes, plt.gca())
plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()