# Part 1: SAM (Segment Anything Model)
# https://github.com/facebookresearch/segment-anything
import torchvision
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# Draw a segmentation mask as a translucent overlay.
def show_mask(mask, ax, random_color=False):
    """Overlay *mask* on matplotlib axes *ax* as an RGBA image.

    mask: (H, W) array (bool/0-1); ax: matplotlib axes;
    random_color: pick a random RGB with alpha 0.6 instead of the
    default dodger-blue.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) RGBA color.
    overlay = mask.reshape(h, w, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
# Draw prompt points: label 1 -> green star, label 0 -> red star.
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on *ax*.

    coords: (N, 2) array of (x, y); labels: (N,) array of 1 (foreground)
    or 0 (background); marker_size: star size in points^2.
    """
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    # Positives are always drawn first, then negatives, matching the
    # original SAM notebook ordering.
    for pts, c in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(pts[:, 0], pts[:, 1], color=c, marker='*', s=marker_size,
                   edgecolor='white', linewidth=1.25)
# Draw a bounding box.
def show_box(box, ax):
    """Draw *box* = (x1, y1, x2, y2) on *ax* as a green, unfilled rectangle."""
    x1, y1 = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    ax.add_patch(plt.Rectangle((x1, y1), width, height,
                               edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
# Draw automatically generated annotations (mask overlays).
def show_anns(anns):
    """Overlay each annotation's segmentation mask on the current axes.

    anns: list of dicts as produced by SamAutomaticMaskGenerator, each
    with 'segmentation' (H x W bool array) and 'area'. Masks are drawn
    largest-area first so smaller masks remain visible on top.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    # NOTE: the original carried unused `polygons`/`color` accumulators
    # left over from the upstream SAM notebook; they have been removed.
    for ann in sorted_anns:
        m = ann['segmentation']
        # Solid random-color image, alpha-masked by the segmentation (0.35).
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:, :, i] = color_mask[i]
        ax.imshow(np.dstack((img, m * 0.35)))
if __name__ == '__main__':
    # Step 1: read the image. OpenCV loads BGR; convert to RGB for
    # SAM and matplotlib.
    image = cv2.imread(r'.\test_images\1.jpg')
    if image is None:
        # cv2.imread silently returns None on a bad path; fail fast with
        # a clear error instead of a cryptic cvtColor exception.
        raise FileNotFoundError(r'cannot read image: .\test_images\1.jpg')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.axis('on')
    plt.show()

    # Step 2: build the SAM model.
    sam_checkpoint = "sam_vit_b_01ec64.pth"  # path to the downloaded checkpoint
    device = "cpu"  # change to "cuda" when a GPU is available
    model_type = "vit_b"  # one of vit_h / vit_l / vit_b, must match the checkpoint
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    # Step 3: wrap the model in a predictor.
    predictor = SamPredictor(sam)

    # Step 4, prompt style 1: a single point prompt (x, y).
    input_point = np.array([[140, 180]])
    input_label = np.array([1])  # 1 = foreground sample, 0 = background sample
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_points(input_point, input_label, plt.gca())
    plt.axis('on')
    plt.show()

    # Step 5: feed the image to the predictor (computes the image embedding).
    predictor.set_image(image)

    # Step 6: predict candidate masks from the point prompt.
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Step 7: show each candidate mask with its confidence score.
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

    # Prompt style 2: a bounding box (x1, y1, x2, y2) — top-left and
    # bottom-right corners.
    input_box = np.array([100, 50, 260, 260])
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],  # predictor expects a (1, 4) box array
        multimask_output=False,
    )
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[0], plt.gca())
    show_box(input_box, plt.gca())
    plt.axis('off')
    plt.show()
# Part 2: FastSAM
# https://github.com/CASIA-IVA-Lab/FastSAM
# Everything mode
python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg
# Text prompt
python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --text_prompt "the yellow dog"
# Box prompt (xywh)
python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --box_prompt "[[570,200,230,400]]"
# Points prompt
python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --point_prompt "[[520,360],[620,300]]" --point_label "[1,0]"
Example (FastSAM Python API):
from fastsam import FastSAM, FastSAMPrompt
model = FastSAM('./weights/FastSAM.pt')
IMAGE_PATH = './images/dogs.jpg'
DEVICE = 'cpu'
everything_results = model(IMAGE_PATH, device=DEVICE, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
# everything prompt
ann = prompt_process.everything_prompt()
# bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
ann = prompt_process.box_prompt(bbox=[[200, 200, 300, 300]])
# text prompt
ann = prompt_process.text_prompt(text='a photo of a dog')
# point prompt
# points default [[0,0]] [[x1,y1],[x2,y2]]
# point_label default [0] [1,0] 0:background, 1:foreground
ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
prompt_process.plot(annotations=ann,output_path='./output/dog.jpg',)
!python FastSAM/Inference.py --model_path FastSAM.pt --img_path 1.jpg --imgsz 512 --point_prompt "[[140,180]]" --point_label "[1]"
!python FastSAM/Inference.py --model_path FastSAM.pt --img_path 1.jpg --imgsz 512 --point_prompt "[[140,180],[1,1]]" --point_label "[1,0]"