prompt points推理
写在前面
网上许多推理代码都是自己输入坐标,比较繁琐,我设计了一个可以交互式选取prompt的程序代码,方面可以方便进行推理。
还可以参考本人其他的交互式选取prompt提示的文章
【SegmentAnything实战——推理2】交互式选取prompt boxes完成SAM推理自己的数据集
【SegmentAnything实战——推理3】交互式选取prompt boxes和prompt points完成SAM推理自己的数据集
参考宝藏博主的blog:
【图像分割】【深度学习】Windows10下SAM(Segment Anything)官方代码Pytorch实现与源码讲解
本人最近也在学习SAM模型,记录相关知识点,分享学习中遇到的问题已经解决的方法
图片来源:
https://www.photophoto.cn/sucai/20184408.html
效果+代码
1. 单张图片上多个prompt points进行预测
鼠标左键点击为前景点(红色), 右键点击为背景点(蓝色),关闭图片开始推理
1.1 选点
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
'''
步骤1: 查看测试图片显示前景和背景的标记点
鼠标左键为选前景点,鼠标右键为选背景点,关闭图像退出选点
'''
# 定义回调函数,用于处理鼠标点击事件
def onclick_points(event):
if event.button == 1: # 鼠标左键
label = 1
elif event.button == 3: # 鼠标右键
label = 0
else:
return
x = event.xdata
y = event.ydata
# 将点击的点的坐标和类型保存到列表中
points_coords.append([x, y])
labels_coords.append(label)
# 在图像上绘制点击的点
plt.title("choose pos_points and neg_points")
plt.scatter(x, y, marker='.', s=200, c='red' if int(label) == 1 else 'blue')
plt.draw()
# 打印保存的点坐标和类型
print(f"坐标:({x}, {y}),类型:{label}")
def show_pic_choose_points(img):
global points_coords
global labels_coords
# 创建一个新的图像窗口并显示图像
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("choose pos_points and neg_points")
# 定义存储点坐标和类型的列表
points_coords = []
labels_coords = []
# 绑定鼠标点击事件处理函数
cid = fig.canvas.mpl_connect('button_press_event', onclick_points)
# 显示图像和等待用户点击点
plt.show()
return points_coords, labels_coords
1.2 multimask_output = False时,只输出一个mask
1.3 multimask_output = True时,输出三个mask
# =============选点
B = 1 # ♥
# ============预测
# 单张points图片只需要predict 多张points则需要predict_torch
if B == 1:
# 选点
points, labels = show_pic_choose_points(image)
input_points, input_labels = np.array(points.copy()), np.array(labels.copy()) # input_points(n,2) input_labels(n,)
masks, scores, logits = predictor.predict(
point_coords=input_points, # 形式 [[x,y], [x,y],...] (n,2)
point_labels=input_labels, # 形式 [0, 1, ...] (n, )
multimask_output=multimask_output_bool, # ♥
)
# =============显示预测
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(input_points, input_labels, plt.gca())
plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
2. 多张图片上多个prompt points进行预测
要注意,每次选择点时,前景点和背景点的总点数必须每张图片是一样的
2.1 选点(以选两张点为例)
2.2 预测
else:
Batch_points = []
Batch_labels = []
for i in range(B):
points, labels = show_pic_choose_points(image)
input_points, input_labels = np.array(points.copy()), np.array(labels.copy())
Batch_points.append(input_points)
Batch_labels.append(input_labels)
Batch_points, Batch_labels = np.array(Batch_points), np.array(Batch_labels) # (B,n,2) (B,n)
points_torch = torch.as_tensor(Batch_points, dtype=torch.float, device=sam.device)
labels_torch = torch.as_tensor(Batch_labels, dtype=torch.float, device=sam.device)
masks, scores, _ = predictor.predict_torch( # 这个函数也不一样 # masks(BxCxHxW)
point_coords=points_torch,
point_labels=labels_torch,
boxes=None, # 形式 [[x1,y1,x2,y2], [x1,y1,x2,y2],...] (n,4)
multimask_output=multimask_output_bool, # ♥
)
# =============显示预测
for i, (mask_i, scores_i, points_i, labels_i) in enumerate(zip(masks, scores, Batch_points, Batch_labels)):
# mask_i (C,H,W) scores_i(C) points_i (n,2)
for j in range(mask_i.shape[0]):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask_i[0].cpu().numpy(), plt.gca())
show_points(points_i, labels_i, plt.gca())
plt.title(f"Points({i}):Mask {j + 1}, Score: {scores_i[j].cpu().numpy():.3f}", fontsize=18)
plt.axis('off')
plt.show()
显示图片有点大,我只显示重要部分
可运行的全部代码
将.py代码放入segment-anything-main的根目录下就行,然后创建一个testset文件夹放入数据集图片,创建checkpoints文件夹存放SAM模型权重,我是三个权重都下载了,都运行了一下,速度还是比较快的。
全部代码需要自定义的地方:
multimask_output_bool = True # 是否需要多个输出
B = 2 # 想输入多少batch个prompt点的提示
path = “./testset/R.jpg” # 图片路径,最好整理起来放在testset文件夹下面
注意:
batch=1的输入和batch>1的输入的区别就是输入形式不同,因此会使用到predictor.predict 和 predictor.predict_torch
先看看predictor.py中的函数参数定义
def predict( # 多个点可以使用,单个box可以使用
self,
point_coords: Optional[np.ndarray] = None, # np.array([[10, 20], [30, 40], [40, 50]]) (3, 2)
point_labels: Optional[np.ndarray] = None, # np.array([1, 0, 1]) (3,)
box: Optional[np.ndarray] = None, # box = np.array([10, 20, 50, 60]) (4,)
mask_input: Optional[np.ndarray] = None, # 1xHxW
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
使用当前设置的图像(1张),对给定的输入提示进行掩码预测。
参数:
point_coords (np.ndarray or None): N个点提示的数组,每个点以像素(X,Y)表示 [[102,204],[100,200]...]
point_labels (np.ndarray or None): 长度为N的点提示的标签数组。1表示前景点,0表示背景点。[1, 0, ...]
box (np.ndarray or None): 给定的框提示的长度为4的数组,以XYXY格式表示。[[x, y, x+w, y+h]]只支持一个框
mask_input (np.ndarray): 低分辨率的掩码输入到模型中,通常来自先前的预测迭代。其形状为1xHxW,其中对于SAM,H=W=256。
multimask_output (bool): 如果为True,则模型将返回三个掩码。对于模糊的输入提示(例如单击),这通常会产生比单个预测更好的掩码。
如果只需要一个单独的掩码,则可以使用模型预测的质量分数来选择最佳的掩码。
对于非模糊的提示,例如多个输入提示,multimask_output=False可以给出更好的结果
return_logits (bool): 如果为True,则返回未阈值化的掩码logits,而不是二进制掩码。
Logits是指模型输出的未经过激活函数(如sigmoid或softmax)处理的原始数值。
Returns:
(np.ndarray): 以CxHxW格式表示的输出掩码,其中C是掩码的数量,(H,W)是原始图像的尺寸。
(np.ndarray): 长度为C的数组,包含每个掩码的质量预测。
(np.ndarray): 形状为CxHxW的数组,其中C是掩码的数量,H=W=256。这些低分辨率的logits可以作为下一次迭代的掩码输入
"""
@torch.no_grad()
def predict_torch(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
对于给定的输入提示,使用当前设置的图像预测掩码。
输入提示是♥批量处理♥的torch张量,预期已经使用ResizeLongestSide方法转换为输入帧的尺寸。
Arguments:
point_coords (torch.Tensor or None): 大小为BxNx2的点提示数组,每个点以像素(X,Y)表示。 # torch.Size([1, n, 2])
point_labels (torch.Tensor or None): 大小为BxN的点提示标签数组。1表示前景点,0表示背景点。 # torch.Size([1, n])
boxes (torch.Tensor or None): 大小为Bx4的框提示数组,以XYXY格式表示。 # torch.Size([1, 1, 4])
mask_input (torch.Tensor): # torch.Size([Bx1xHxW])
multimask_output (bool): 如果为True,则模型将返回三个掩码。对于模糊的输入提示(例如单击),这通常会产生比单个预测更好的掩码。
如果只需要一个单独的掩码,则可以使用模型预测的质量分数来选择最佳的掩码。
对于非模糊的提示,例如多个输入提示,multimask_output=False可以给出更好的结果。
Returns:
(torch.Tensor): 以BxCxHxW格式表示的输出掩码,其中B是批次大小,C是掩码的数量,(H,W)是原始图像的尺寸。
(torch.Tensor): 形状为BxC的数组,包含每个掩码的质量预测。
(torch.Tensor): 形状为BxCxHxW的数组,其中B是批次大小,C是掩码的数量,H=W=256。这些低分辨率的logits可以作为下一次迭代的掩码输入。
"""
全部代码
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
'''
步骤1: 查看测试图片显示前景和背景的标记点
鼠标左键为选前景点,鼠标右键为选背景点,关闭图像退出选点
'''
# 定义回调函数,用于处理鼠标点击事件
def onclick_points(event):
if event.button == 1: # 鼠标左键
label = 1
elif event.button == 3: # 鼠标右键
label = 0
else:
return
x = event.xdata
y = event.ydata
# 将点击的点的坐标和类型保存到列表中
points_coords.append([x, y])
labels_coords.append(label)
# 在图像上绘制点击的点
plt.title("choose pos_points and neg_points")
plt.scatter(x, y, marker='.', s=200, c='red' if int(label) == 1 else 'blue')
plt.draw()
# 打印保存的点坐标和类型
print(f"坐标:({x}, {y}),类型:{label}")
def show_pic_choose_points(img):
global points_coords
global labels_coords
# 创建一个新的图像窗口并显示图像
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("choose pos_points and neg_points")
# 定义存储点坐标和类型的列表
points_coords = []
labels_coords = []
# 绑定鼠标点击事件处理函数
cid = fig.canvas.mpl_connect('button_press_event', onclick_points)
# 显示图像和等待用户点击点
plt.show()
return points_coords, labels_coords
def show_points(coords, labels, ax):
# 筛选出前景目标标记点
pos_points = coords[labels == 1]
# 筛选出背景目标标记点
neg_points = coords[labels == 0]
# x-->pos_points[:, 0] y-->pos_points[:, 1]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='red', marker='.', s=200) # 前景的标记点显示
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='blue', marker='.', s=200) # 背景的标记点显示
def show_mask(mask, ax, random_color=False):
if random_color: # 掩膜颜色是否随机决定
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
path = "./testset/R.jpg"
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
multimask_output_bool = True # ♥
# ============加载模型
sam_checkpoints_b = "./checkpoints/sam_vit_b_01ec64.pth"
sam_checkpoints_l = "./checkpoints/sam_vit_l_0b3195.pth"
sam_checkpoints_h = "./checkpoints/sam_vit_h_4b8939.pth"
# 模型类型
model_type_b = "vit_b"
model_type_l = "vit_l"
model_type_h = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type_h](sam_checkpoints_h) # ♥
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# =============选点
B = 2 # ♥
# ============预测
# 单张points图片只需要predict 多张points则需要predict_torch
if B == 1:
# 选点
points, labels = show_pic_choose_points(image)
input_points, input_labels = np.array(points.copy()), np.array(labels.copy()) # input_points(n,2) input_labels(n,)
masks, scores, logits = predictor.predict(
point_coords=input_points, # 形式 [[x,y], [x,y],...] (n,2)
point_labels=input_labels, # 形式 [0, 1, ...] (n, )
multimask_output=multimask_output_bool, # ♥
)
# =============显示预测
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(input_points, input_labels, plt.gca())
plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
else:
Batch_points = []
Batch_labels = []
for i in range(B):
points, labels = show_pic_choose_points(image)
input_points, input_labels = np.array(points.copy()), np.array(labels.copy())
Batch_points.append(input_points)
Batch_labels.append(input_labels)
Batch_points, Batch_labels = np.array(Batch_points), np.array(Batch_labels) # (B,n,2) (B,n)
points_torch = torch.as_tensor(Batch_points, dtype=torch.float, device=sam.device)
labels_torch = torch.as_tensor(Batch_labels, dtype=torch.float, device=sam.device)
masks, scores, _ = predictor.predict_torch( # 这个函数也不一样 # masks(BxCxHxW)
point_coords=points_torch,
point_labels=labels_torch,
boxes=None, # 形式 [[x1,y1,x2,y2], [x1,y1,x2,y2],...] (n,4)
multimask_output=multimask_output_bool, # ♥
)
# =============显示预测
for i, (mask_i, scores_i, points_i, labels_i) in enumerate(zip(masks, scores, Batch_points, Batch_labels)):
# mask_i (C,H,W) scores_i(C) points_i (n,2)
for j in range(mask_i.shape[0]):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask_i[0].cpu().numpy(), plt.gca())
show_points(points_i, labels_i, plt.gca())
plt.title(f"Points({i}):Mask {j + 1}, Score: {scores_i[j].cpu().numpy():.3f}", fontsize=18)
plt.axis('off')
plt.show()