Resources
Paper: https://ai.facebook.com/research/publications/segment-anything/
Baidu Netdisk link: https://pan.baidu.com/s/1wLC14RExHVEw2Z2AVlxcsQ (extraction code: 1234)
GitHub repository: https://github.com/facebookresearch/segment-anything
Model weights (Baidu Netdisk): https://pan.baidu.com/s/1X8UNtZD3gh3Mg5OH8h7mQg (extraction code: lr54)
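The contents of the weights archive are not listed here. The official segment-anything README names three checkpoints, each tied to one sam_model_registry key, and loading a checkpoint with the wrong model_type typically fails with a state-dict size mismatch. The sketch below is a hypothetical helper (model_type_for is not part of segment-anything) that picks the registry key from the checkpoint filename, assuming the files keep their official names:

```python
from pathlib import Path

# Official SAM checkpoint filenames (per the segment-anything README) and the
# sam_model_registry key each must be loaded with. Hypothetical helper table.
CHECKPOINT_TO_MODEL_TYPE = {
    "sam_vit_h_4b8939.pth": "vit_h",
    "sam_vit_l_0b3195.pth": "vit_l",
    "sam_vit_b_01ec64.pth": "vit_b",
}


def model_type_for(checkpoint_path: str) -> str:
    """Return the registry key for a known checkpoint file (hypothetical helper)."""
    name = Path(checkpoint_path).name
    try:
        return CHECKPOINT_TO_MODEL_TYPE[name]
    except KeyError:
        raise ValueError(f"unknown SAM checkpoint: {name}") from None


print(model_type_for("weights/sam_vit_h_4b8939.pth"))  # vit_h
```

Renamed checkpoints fall through to the ValueError, which is deliberate: guessing the architecture from an unknown filename would only move the failure to load time.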
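Algorithm diagram
Model overview
The Segment Anything Model (SAM) produces high-quality object masks from input prompts such as points or boxes, and it can also generate masks for every object in an image. It was trained on a dataset of 11 million images and 1.1 billion masks, and it shows strong zero-shot performance on a variety of segmentation tasks.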
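The automatic mask generator used in the test code needs no prompts, but point or box prompting goes through SamPredictor.predict, which takes point_coords as an (N, 2) array of (x, y) pixel coordinates, point_labels with 1 for a foreground point and 0 for a background point, and box as (x0, y0, x1, y1). The helper build_prompts below is hypothetical and uses numpy only; the commented lines sketch how its output would be passed to a loaded model, assuming sam has been built as in the test code and rgb_image is an RGB uint8 array.

```python
import numpy as np


def build_prompts(fg_points, bg_points=(), box=None):
    """Pack click points and an optional box into the array layout that
    SamPredictor.predict expects (hypothetical helper, numpy only)."""
    fg = list(fg_points)
    bg = list(bg_points)
    # (N, 2) array of (x, y) pixel coordinates, foreground points first
    point_coords = np.array(fg + bg, dtype=np.float32).reshape(-1, 2)
    # Label 1 = foreground point, 0 = background point
    point_labels = np.array([1] * len(fg) + [0] * len(bg), dtype=np.int32)
    # Box in (x0, y0, x1, y1) pixel format, or None when no box is given
    box_arr = None if box is None else np.array(box, dtype=np.float32)
    return point_coords, point_labels, box_arr


coords, labels, box = build_prompts([(500, 375)], bg_points=[(100, 100)], box=(400, 300, 650, 500))
print(coords.shape, labels.tolist(), box.tolist())  # (2, 2) [1, 0] [400.0, 300.0, 650.0, 500.0]

# With a loaded model (not run here):
#   from segment_anything import SamPredictor
#   predictor = SamPredictor(sam)
#   predictor.set_image(rgb_image)
#   masks, scores, logits = predictor.predict(
#       point_coords=coords, point_labels=labels, box=box, multimask_output=True)
```

With multimask_output=True the predictor returns several candidate masks per prompt together with their scores, which helps when a single click is ambiguous (for example, part versus whole object).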
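Installation
Step 1:
In an environment with python>=3.8, pytorch>=1.7 and torchvision>=0.8, run:
```shell
git clone git@github.com:facebookresearch/segment-anything.git
# or, if no SSH key is configured for GitHub:
# git clone https://github.com/facebookresearch/segment-anything.git
```
Step 2:
```shell
cd segment-anything
pip install -e .
# the test code below also needs OpenCV and matplotlib
pip install opencv-python matplotlib
```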
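Before running the test code, the version requirements from Step 1 can be checked with a short script. This is a minimal sketch with hypothetical helper names (parse_version, check_environment); it reads installed package versions through importlib.metadata, which is itself only available from Python 3.8, so on older interpreters the import fails before the check runs.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions stated in the installation step
REQUIREMENTS = {"torch": (1, 7), "torchvision": (0, 8)}


def parse_version(text):
    """Turn '2.1.0+cu118' into (2, 1, 0); ignores local/pre-release suffixes."""
    match = re.match(r"\d+(\.\d+)*", text)
    if match is None:
        raise ValueError(f"unparseable version: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def check_environment():
    """Return a list of human-readable problems (an empty list means OK)."""
    problems = []
    if sys.version_info < (3, 8):
        problems.append(f"python {sys.version.split()[0]} < 3.8")
    for package, minimum in REQUIREMENTS.items():
        try:
            installed = parse_version(version(package))
        except PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        # Tuple comparison: (1, 13, 1) >= (1, 7), while (1, 6, 0) < (1, 7)
        if installed < minimum:
            problems.append(f"{package} {installed} < {minimum}")
    return problems


print(check_environment() or "environment OK")
```

Comparing integer tuples rather than raw strings avoids the classic trap where "1.13" sorts before "1.7" lexicographically.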
Test code
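```python
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry


def show_anns(anns):
    """Overlay each mask on the current axes in a random semi-transparent color."""
    if len(anns) == 0:
        return
    # Draw large masks first so smaller ones stay visible on top
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']  # boolean array of shape (H, W)
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random(3)
        for i in range(3):
            img[:, :, i] = color_mask[i]
        # Use the mask as the alpha channel: 35% opacity inside the mask, 0 outside
        ax.imshow(np.dstack((img, m * 0.35)))


sam_checkpoint = "sam_vit_h_4b8939.pth"
# "default" is an alias for "vit_h" in sam_model_registry; both match this checkpoint
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Define the parameters and initialize the generator
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=60,  # Requires open-cv to run post-processing
)

# Read the image (the test image is 1024×1024; SAM resizes the longest side
# to 1024 internally, so other sizes also work)
img_path = './test/test.jpg'
img = cv.imread(img_path)
if img is None:
    raise FileNotFoundError(img_path)
# OpenCV loads images as BGR; SAM and matplotlib expect RGB (HWC, uint8)
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)

# Run prediction
masks = mask_generator_2.generate(img)
print(len(masks))
print(masks[0].keys())

# Draw the masks on the original image and save the figure to the given path
plt.figure(figsize=(20, 20))
plt.imshow(img)
show_anns(masks)
plt.axis('off')
plt.savefig(fname="pic_test.png")
```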
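Each element returned by generate() is a dict; with the default output mode it includes 'segmentation' (a boolean H×W mask), 'area', 'bbox' (XYWH), 'predicted_iou', 'stability_score', 'point_coords' and 'crop_box'. The generator already drops masks below pred_iou_thresh and stability_score_thresh, so the sketch below only shows optional extra filtering afterwards, plus a numpy-only overlay that yields an image array instead of a matplotlib figure. filter_masks and overlay_masks are hypothetical helpers, and the two records are synthetic stand-ins for real generator output.

```python
import numpy as np


def filter_masks(masks, min_area=0, min_iou=0.0):
    """Keep records whose 'area' and 'predicted_iou' pass the given thresholds."""
    return [m for m in masks if m["area"] >= min_area and m["predicted_iou"] >= min_iou]


def overlay_masks(image, masks, alpha=0.35, seed=0):
    """Blend each boolean 'segmentation' onto an RGB uint8 image in a random color.

    Masks are painted largest-first, as in show_anns, so small masks stay on top.
    """
    rng = np.random.default_rng(seed)
    out = image.astype(np.float32)
    for ann in sorted(masks, key=lambda a: a["area"], reverse=True):
        m = ann["segmentation"]
        color = rng.integers(0, 256, size=3).astype(np.float32)
        # Alpha-blend only the pixels covered by this mask
        out[m] = (1 - alpha) * out[m] + alpha * color
    return np.clip(out, 0, 255).astype(np.uint8)


# Synthetic stand-ins for two records returned by SamAutomaticMaskGenerator.generate
h, w = 4, 6
big = np.zeros((h, w), dtype=bool)
big[:, :3] = True
small = np.zeros((h, w), dtype=bool)
small[0, 5] = True
records = [
    {"segmentation": big, "area": int(big.sum()), "predicted_iou": 0.95},
    {"segmentation": small, "area": int(small.sum()), "predicted_iou": 0.90},
]

kept = filter_masks(records, min_area=2)
print(len(kept))  # 1
blended = overlay_masks(np.zeros((h, w, 3), dtype=np.uint8), kept)
print(blended.shape, blended.dtype)  # (4, 6, 3) uint8
```

Assuming img holds the RGB image passed to generate(), a real result could be written to disk with cv.imwrite("overlay.png", cv.cvtColor(overlay_masks(img, masks), cv.COLOR_RGB2BGR)); the conversion back to BGR is needed because OpenCV writes BGR.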