prompt boxes推理
写在前面
网上许多推理代码都是自己输入坐标,比较繁琐,我设计了一个可以交互式选取prompt的程序代码,方面可以方便进行推理。
还可以参考本人其他的交互式选取prompt提示的文章
参考宝藏博主的blog:
【图像分割】【深度学习】Windows10下SAM(Segment Anything)官方代码Pytorch实现与源码讲解
本人最近也在学习SAM模型,记录相关知识点,分享学习中遇到的问题已经解决的方法
图片来源:
https://www.photophoto.cn/sucai/20184408.html
效果+代码
1. 单个prompt box进行预测
鼠标左键点击一下矩形框起点, 再点击一下矩形框终点,自动形成矩形框,关闭图片开始推理
1.1 效果
如果multiple_box_bool = False 那么意味着只能选一个目标框,目标框选定后,想要再选择box,是没有反应的,除非multiple_box_bool = True, 可以看看后面的代码
选定的目标框坐标:
单个输出:
多个输出:
2. 多个prompt boxes进行预测
这个多个prompt boxes, 可以是一张图片上多个目标框,
或者Batch个输入,但是batch中每个子部分只能有一个目标框
鼠标左键点击一下矩形框起点, 再点击一下矩形框终点,自动形成矩形框,关闭图片开始推理
2.1 效果
如果multiple_box_bool = False 那么意味着只能选一个目标框,目标框选定后,想要再选择box,是没有反应的,除非multiple_box_bool = True, 可以看看后面的代码
输出:
可运行的全部代码
将.py代码放入segment-anything-main的根目录下就行,然后创建一个testset文件夹放入数据集图片,创建checkpoints文件夹存放SAM模型权重,我是三个权重都下载了,都运行了一下,速度还是比较快的。
全部代码需要自定义的地方:
multimask_output_bool = True # 是否需要多个输出
multiple_box_bool = True # 是否要选择多个目标框boxes
path = “./testset/R.jpg” # 图片路径,最好整理起来放在testset文件夹下面
注意:
multiple_box_bool = True /False 的区别就是输入形式不同,因此会使用到predictor.predict 和 predictor.predict_torch
完整代码
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
# 定义回调函数,用于处理鼠标点击事件
def onclick_box(event):
if event.button == 1: # 鼠标左键
if len(points_box) == 1:
if flag == False:
return
clicked_points.append([event.xdata, event.ydata])
# 在图像上绘制点击的点
if flag:
plt.title("choose some boxes!")
else:
plt.title("choose a box!")
plt.scatter(event.xdata, event.ydata, marker='.', s=200, c='red')
plt.draw()
if len(clicked_points) == 2: #当点击了两个点
x1, y1 = clicked_points[0]
x2, y2 = clicked_points[1]
points_box.append([x1, y1, x2, y2])
# 在图像上绘制矩形
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
plt.gca().add_patch(rect)
plt.draw()
# 打印保存的矩形坐标和类型
print(f"矩形坐标:({x1}, {y1}, {x2}, {y2})")
clicked_points.clear() # 清空已点击的点列表
else:
return
def show_pic_choose_boxes(img, multiple_box_bool):
global points_box
global clicked_points
global flag
# 创建一个新的图像窗口并显示图像
fig, ax = plt.subplots()
ax.imshow(img)
flag = multiple_box_bool
# 设置标题
if flag:
ax.set_title("choose some boxes!")
else:
ax.set_title("choose a box!")
# 定义存储点坐标和类型的列表
points_box = []
clicked_points = []
boxes = []
# 绑定鼠标点击事件处理函数
cid = fig.canvas.mpl_connect('button_press_event', onclick_box)
# 显示图像和等待用户点击点
plt.show()
boxes = points_box.copy()
return boxes
def show_box(boxes, ax):
# 画出标定框 x0 y0是起始坐标
for box in boxes:
x0, y0 = box[0], box[1]
# w h 是框的尺寸
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
def show_mask(mask, ax, random_color=False):
if random_color: # 掩膜颜色是否随机决定
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
path = "./testset/R.jpg"
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
multimask_output_bool = True # ♥
# =============选框
multiple_box_bool = True # ♥
boxes = show_pic_choose_boxes(image, multiple_box_bool)
input_boxes = np.array(boxes.copy())
print(input_boxes)
# ============加载模型
sam_checkpoints_b = "./checkpoints/sam_vit_b_01ec64.pth"
sam_checkpoints_l = "./checkpoints/sam_vit_l_0b3195.pth"
sam_checkpoints_h = "./checkpoints/sam_vit_h_4b8939.pth"
# 模型类型
model_type_b = "vit_b"
model_type_l = "vit_l"
model_type_h = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type_h](sam_checkpoints_h) # ♥
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# ==========如果是多个目标狂,则需要对box_points进行转换
# ♥ 这个多个目标框的意思其实是 一张图片上多个目标框,或者多张图片上只有一个目标框
if multiple_box_bool:
boxes_torch = torch.as_tensor(input_boxes, dtype=torch.float, device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_torch, image.shape[:2]) # (n,4) [[ 20. 40. 100. 120.]]
# ============多个目标框预测
masks, _, _ = predictor.predict_torch( # 这个函数也不一样
point_coords=None,
point_labels=None,
boxes=transformed_boxes, # 形式 [[x1,y1,x2,y2], [x1,y1,x2,y2],...] (n,4)
multimask_output=False, # ♥
)
# ==========多个目标狂显示预测
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
show_box(boxes_torch.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
else:
# ============单个预测
masks, scores, logits = predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes, # 输入格式为(1,4) 只能是一个box的坐标
multimask_output=multimask_output_bool, # ♥
)
# =============显示预测
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_box(input_boxes, plt.gca())
plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
# 总结
### 总体来说,如果想全面了解sam模型代码,最重要的还是把SAM的predictor.py文件好好看看,捋清楚维度的变化,什么时候是numpy,什么时候是tensor,输入形式是什么?输出形式又是什么?这里面大有学问呀!