如何利用SAM(segment-anything)制作自己的分割数据集

1. 环境搭建

        1.1 GitHub 地址:https://github.com/facebookresearch/segment-anything(该仓库提供 Segment Anything Model (SAM) 的推理代码、训练好的模型权重下载链接,以及演示用法的示例 notebook)。

        1.2 步骤        

        

该代码需要 python>=3.8、pytorch>=1.7 和 torchvision>=0.8。请按照 PyTorch 官方说明(https://pytorch.org/get-started/locally/)安装 PyTorch 和 TorchVision,强烈建议安装支持 CUDA 的版本。

安装 Segment Anything:

pip install git+https://github.com/facebookresearch/segment-anything.git

或在本地克隆存储库并安装

git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .

以下可选依赖项对于掩模后处理、以 COCO 格式保存掩模、示例笔记本以及以 ONNX 格式导出模型是必需的。 运行示例笔记本还需要 jupyter。

pip install opencv-python pycocotools matplotlib onnxruntime onnx

 2. 制作蒙版

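相关代码如下。注意:代码中导入的是 MobileSAM 的 `mobile_sam` 包(模型类型 `vit_t`,权重 `../weights/mobile_sam.pt`),而不是上面安装的 `segment_anything`;需另外安装 MobileSAM(https://github.com/ChaoningZhang/MobileSAM)并下载其权重,或者改为 `from segment_anything import ...` 并使用对应的模型类型与权重文件。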

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
# Overlay masks on the current matplotlib axes.
def show_anns(anns, start=10, stop=200, alpha=0.35):
    """Draw a random-coloured, semi-transparent overlay of masks on the current axes.

    Masks are ordered by ``'area'``, largest first, and only the slice
    ``[start:stop]`` of that ordering is drawn. The defaults skip the 10
    largest masks and draw at most 190 of the rest (presumably to hide large
    background regions); pass ``start=0, stop=None`` to draw every mask.

    Args:
        anns: mask records, each with a boolean ``'segmentation'`` array and an
            ``'area'`` value; all segmentations share the same height/width.
        start: first index (in the area-sorted list) to draw.
        stop: end index (exclusive) in the area-sorted list; ``None`` means all.
        alpha: opacity of each coloured mask.
    """
    if not anns:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    # Keep the axes limits set by the underlying image.
    ax.set_autoscale_on(False)

    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA canvas that stays fully transparent until a mask paints over it.
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in sorted_anns[start:stop]:
        overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [alpha]])
    ax.imshow(overlay)

image = cv2.imread('./images/05.jpg')
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError('cannot read ./images/05.jpg')
# OpenCV loads BGR; SAM and matplotlib work on RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
print(image.shape)

plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()

import sys
sys.path.append("..")
from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Model configuration: checkpoint file, registry key, and compute device.
sam_checkpoint = "../weights/mobile_sam.pt"
model_type = "vit_t"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model, move it to the device and put it in inference mode.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device).eval()

# Generate masks for the whole image automatically.
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

# Show the generated masks on top of the photo.
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

# Save the union of all masks as a binary image.
def save_mask(anns, path='res.jpg'):
    """Write the union of all masks in *anns* to *path* as a 0/255 image.

    Args:
        anns: mask records, each with a boolean ``'segmentation'`` array; all
            segmentations share the same height/width.
        path: output image file. ``None`` skips writing, which is useful to get
            the mask in memory only. NOTE: a ``.jpg`` path is lossy, so the saved
            mask picks up grey values near edges; prefer ``.png`` if it is reloaded.

    Returns:
        The ``uint8`` mask (255 where any segmentation is set, 0 elsewhere), or
        ``None`` when *anns* is empty.
    """
    if not anns:
        return None
    height, width = anns[0]['segmentation'].shape[:2]
    # 8-bit single-channel buffer: the native format for a saved binary mask.
    mask = np.zeros((height, width), dtype=np.uint8)
    # A union is order-independent, so no sorting by area is needed.
    for ann in anns:
        mask[ann['segmentation']] = 255
    if path is not None:
        cv2.imwrite(path, mask)
    return mask


# Write the union mask to res.jpg; save_mask handles ordering itself, so
# the masks can be passed straight through without pre-sorting or copying.
save_mask(masks)

3. 制作 COCO 格式数据集(可用于语义分割、目标检测和实例分割)

以下代码需接在第 2 节的蒙版代码之后运行:它用到其中生成的 `masks` 变量和保存的 `res.jpg`。

# Extract the outline of every individual mask.
import cv2
import numpy as np
image = cv2.imread('./images/05.jpg')
if image is None:
    raise FileNotFoundError('cannot read ./images/05.jpg')
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
img_ = np.zeros_like(image)
# One 0/255 uint8 image per mask, built from `masks` (mask-generation step).
# Scaling to 255 matters: Canny's 50/150 thresholds find no edges in a 0/1 image.
gray_images = [ann['segmentation'].astype(np.uint8) * 255 for ann in masks]
for gray_image in gray_images:
    # Morphological opening removes small specks before edge detection.
    gray_image = cv2.morphologyEx(gray_image, cv2.MORPH_OPEN, kernel)
    edges = cv2.Canny(gray_image, 50, 150)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img_, contours, -1, (255, 255, 255), 2)
cv2.imwrite("counte2.png", img_)

# Mask minus outlines: carving every mask's outline out of the union mask
# presumably splits touching segments into separate blobs.
im = cv2.imread('images/05.jpg', cv2.IMREAD_GRAYSCALE)
image1 = cv2.imread('res.jpg', cv2.IMREAD_GRAYSCALE)
image2 = cv2.imread('counte2.png', cv2.IMREAD_GRAYSCALE)
# res.jpg is JPEG-compressed; re-binarize it to drop compression noise.
_, image1 = cv2.threshold(image1, 127, 255, cv2.THRESH_BINARY)
img = cv2.subtract(image1, image2)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
dst2 = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)

# Blend of the cleaned mask over the grayscale photo (computed, not saved).
re_img = cv2.addWeighted(dst2, 0.2, im, 0.8, 0)
cv2.imwrite("res3.jpg", dst2)

plt.figure(figsize=(20, 20))
plt.imshow(dst2, cmap='gray')
plt.axis('off')
plt.show()

# Store the annotations in COCO format.
import json
orig_img = cv2.imread('./images/05.jpg')
if orig_img is None:
    raise FileNotFoundError('cannot read ./images/05.jpg')
image = cv2.imread('res3.jpg', cv2.IMREAD_GRAYSCALE)
# res3.jpg is JPEG-compressed, so its edges carry grey noise; re-binarize first.
_, binary = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
# Trace each blob's outer boundary on the filled mask itself. Tracing Canny
# edges instead gives open, near-zero-area curves for blobs touching the
# image border, where Canny typically emits no edge.
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
images = [
    {
        'file_name': '05.jpg',
        'height': int(orig_img.shape[0]),
        'width': int(orig_img.shape[1]),
        'id': 1,
    },
]

categories = [
    {
        'id': 1,
        'name': 'qituan',
    },
]

annotations = []
for contour in contours:
    # A COCO polygon needs at least 3 vertices; pycocotools also treats a
    # 4-number segmentation as a bounding box, so skip degenerate contours.
    if len(contour) < 3:
        continue
    x, y, w, h = cv2.boundingRect(contour)
    annotations.append({
        # contour is (N, 1, 2); COCO wants one flat [x1, y1, x2, y2, ...] list per polygon.
        'segmentation': [contour.reshape(-1).tolist()],
        'bbox': [x, y, w, h],
        'area': float(cv2.contourArea(contour)),
        'iscrowd': 0,
        'image_id': 1,
        'category_id': 1,
        'id': len(annotations) + 1,
    })

coco_data = {
    'images': images,
    'annotations': annotations,
    'categories': categories,
}

print(coco_data)

output_file_path = 'coco_data.json'

# Serialize the data and write it to a JSON file.
with open(output_file_path, 'w', encoding='utf-8') as f:
    json.dump(coco_data, f, indent=4)

4. 验证COCO格式数据


import cv2
import random
import json, os
from pycocotools.coco import COCO
from matplotlib import pyplot as plt
import numpy as np

train_json = 'coco_data.json'
train_path = './images/'
coco = COCO(train_json)

list_imgIds = coco.getImgIds(catIds=1)
img_id = list_imgIds[0]
img = coco.loadImgs(img_id)[0]
image = cv2.imread(train_path + img['file_name'])  # read the image (BGR)
if image is None:
    raise FileNotFoundError(train_path + img['file_name'])
# matplotlib expects RGB, OpenCV loads BGR.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Query the annotations of the image that was actually loaded.
img_annIds = coco.getAnnIds(imgIds=img_id, catIds=1, iscrowd=None)
anns = coco.loadAnns(img_annIds)

# Segmentation: fill every polygon of each annotation with a random colour.
img1 = image.copy()
for ann in anns:
    color = tuple(np.random.randint(0, 255, 3).tolist())
    for polygon in ann['segmentation']:
        # drawContours needs int32 points shaped (N, 1, 2); NumPy's default
        # integer dtype is platform dependent (int64 on Linux) and is rejected.
        contour_restored = np.asarray(polygon, dtype=np.int32).reshape(-1, 1, 2)
        cv2.drawContours(img1, [contour_restored], -1, color, thickness=cv2.FILLED)

# 20.0 because the author's images are 2000x2000; figure size is not derived automatically.
plt.rcParams['figure.figsize'] = (20.0, 20.0)
plt.imshow(img1)
plt.show()

# Object detection: draw every bounding box on a fresh copy of the image.
image1 = image.copy()
for ann in anns:
    x, y, w, h = ann['bbox']
    cv2.rectangle(image1, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 255), 2)

plt.rcParams['figure.figsize'] = (20.0, 20.0)
plt.imshow(image1)
plt.show()

  • 16
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值