参考自github:segment-anything/automatic_mask_generator_example.ipynb at main · facebookresearch/segment-anything · GitHub官网:https://segment-anything.com/
代码需要python>=3.8
, 以及pytorch>=1.7
和torchvision>=0.8
。请按照此处的说明安装 PyTorch 和 TorchVision 依赖项。强烈建议安装支持 CUDA 的 PyTorch 和 TorchVision。
安装段任何东西:
pip install git+https://github.com/facebookresearch/segment-anything.git
或者在本地克隆存储库并安装
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .
以下可选依赖项对于掩码后处理、以 COCO 格式保存掩码、示例笔记本以及以 ONNX 格式导出模型是必需的。jupyter
还需要运行示例笔记本。
pip install opencv-python pycocotools matplotlib onnxruntime onnx
首先下载一个模型checkpoint。然后只需几行就可以使用该模型从给定的提示中获取掩码:
from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
此外,可以从命令行为图像生成遮罩:
python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output>
模型权重点
该模型的三种模型版本具有不同的骨干尺寸。这些模型可以通过运行来实例化
from segment_anything import sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
单击下面的链接下载相应模型类型的权重,分别是h代表了2.6G,l代表1.2G,b代表375M,注意如果显存不够是无法跑大权重的。
查看自己显存的代码:
grep -i --color memory /var/log/Xorg.0.log
default
或vit_h
:ViT-H SAM 型号。vit_l
: ViT-L SAM 模型。vit_b
: ViT-B SAM 型号。
单张图片分割
使用python代码对单张图片进行全景切割:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# 改变sam_checkpoint,model_type,device为你想要的模型
sam_checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
img[:,:,3] = 0
for ann in sorted_anns:
m = ann['segmentation']
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
ax.imshow(img)
# 在image处加入要处理的图片路径
image = cv2.imread('/*.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
批量图片处理
批量处理一个文件夹中的图片:
# 输入和输出的路径
input_dir = "路径"
output_dir = "路径"
# 循环读取文件并保存
for image_path in glob.glob(os.path.join(input_dir, "*.jpg")):
# 读取图片
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 生成mask
masks = mask_generator.generate(image)
# 保存带有
output_path = os.path.join(output_dir, os.path.basename(image_path))
save_anns(masks, image, output_path)