前提:
用 LabelImg 对图片进行标注(保存格式选择 PascalVOC),得到每张图片对应的 XML 标注文件。XML 文件名需与图片文件名相同(仅扩展名不同,如 a.jpg 对应 a.xml),下面的代码据此为每张图片找到对应的标注。
1.单目标分割
1.1 单目标分割代码
import torch
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np
import os
import glob
import xml.etree.ElementTree as ET
# Single-object segmentation: for each image, prompt SAM with the Pascal VOC
# bounding box from its labelImg XML file and save one black/white mask.
checkpoint = "./weight/sam_vit_h_4b8939.pth"  # SAM weights downloaded from GitHub
model_type = "vit_h"
# Fall back to the CPU when CUDA is unavailable (much slower, but still runs).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

image_dir = r"D:\Desktop\mult_test\images"
# All image paths in image_dir; adjust the pattern to your image format.
image_files = glob.glob(os.path.join(image_dir, '*.jpg'))
# Directory the generated masks are written to.
save_dir = r"D:\Desktop\mult_test\mask"
# Directory holding the labelImg Pascal VOC XML annotations.
xml_dir = r'D:\Desktop\mult_test\label'
# cv2.imwrite does not create directories; it just returns False.
os.makedirs(save_dir, exist_ok=True)

for image_file in image_files:
    # The annotation shares the image's base name, e.g. a.jpg -> a.xml.
    image_stem = os.path.splitext(os.path.basename(image_file))[0]
    xml_file = os.path.join(xml_dir, image_stem + '.xml')
    if not os.path.isfile(xml_file):
        print(f"skip {image_file}: annotation {xml_file} not found")
        continue
    objects = ET.parse(xml_file).getroot().findall('object')
    if not objects:
        print(f"skip {image_file}: no <object> in {xml_file}")
        continue
    if len(objects) > 1:
        # Only one mask file is written per image, so the last object wins.
        # Use the multi-object script to merge every object into one mask.
        print(f"note {image_file}: {len(objects)} objects, using the last one")

    image = cv2.imread(image_file)
    if image is None:  # cv2.imread returns None instead of raising
        print(f"skip {image_file}: cannot read image")
        continue
    # OpenCV loads images as BGR; SamPredictor.set_image expects RGB by default.
    predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # [xmin, ymin, xmax, ymax] of the object; float() also accepts "12.0".
    bbox_elem = objects[-1].find('bndbox')
    input_box = np.array([int(float(bbox_elem.find(tag).text))
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
    # With multimask_output=False only masks[0] is used as the (H, W) mask.
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box,
        multimask_output=False,
    )
    # Object pixels -> 0 (black), background -> 255 (white). np.where yields a
    # platform int array (int64 on most platforms) that OpenCV 4.x cannot
    # write, hence the explicit uint8 cast.
    binary_mask = np.where(masks[0] == 1, 0, 255).astype(np.uint8)
    # NOTE(review): the mask keeps the image's own file name (so .jpg here);
    # JPEG is lossy and will smear the 0/255 edges - consider saving as .png.
    cv2.imwrite(os.path.join(save_dir, os.path.basename(image_file)), binary_mask)
1.2 单目标效果图
2.多目标分割
2.1 多目标分割代码
import torch
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np
import os
import glob
import xml.etree.ElementTree as ET
# Multi-object segmentation: prompt SAM with every Pascal VOC bounding box of
# an image in one batch and save the union of all objects as one mask.
checkpoint = "./weight/sam_vit_h_4b8939.pth"  # SAM weights path
model_type = "vit_h"
# Fall back to the CPU when CUDA is unavailable (much slower, but still runs).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

image_dir = r"D:\Desktop\mult_test\images"
# All image paths in image_dir; adjust the pattern to your image format.
image_files = glob.glob(os.path.join(image_dir, '*.jpg'))
# Directory the generated masks are written to.
save_dir = r"D:\Desktop\mult_test\mask"
# Directory holding the labelImg Pascal VOC XML annotations.
xml_dir = r'D:\Desktop\mult_test\label'
# cv2.imwrite does not create directories; it just returns False.
os.makedirs(save_dir, exist_ok=True)

for image_file in image_files:
    # The annotation shares the image's base name, e.g. a.jpg -> a.xml.
    image_stem = os.path.splitext(os.path.basename(image_file))[0]
    xml_file = os.path.join(xml_dir, image_stem + '.xml')
    if not os.path.isfile(xml_file):
        print(f"skip {image_file}: annotation {xml_file} not found")
        continue
    root = ET.parse(xml_file).getroot()

    # One [xmin, ymin, xmax, ymax] row per annotated object; float() also
    # accepts coordinates written as "12.0".
    data_list = []
    for object_elem in root.findall('object'):
        bbox_elem = object_elem.find('bndbox')
        data_list.append([int(float(bbox_elem.find(tag).text))
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
    if not data_list:
        # An empty box batch cannot be fed to SAM; nothing to segment.
        print(f"skip {image_file}: no <object> in {xml_file}")
        continue

    image = cv2.imread(image_file)
    if image is None:  # cv2.imread returns None instead of raising
        print(f"skip {image_file}: cannot read image")
        continue
    # OpenCV loads images as BGR; SamPredictor.set_image expects RGB by default.
    predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Batched box prompting, following the official multi-box example.
    input_boxes = torch.tensor(data_list, dtype=torch.float, device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    # masks is indexed as (box, 1, H, W); a pixel belongs to the foreground if
    # any box's mask covers it. This equals AND-ing the inverted per-box masks,
    # and also works when the image has only one object.
    union = (masks[:, 0] == 1).any(dim=0).cpu().numpy()
    # Objects -> 0 (black), background -> 255 (white); uint8 so imwrite accepts it.
    binary_mask = np.where(union, 0, 255).astype(np.uint8)
    # NOTE(review): the mask keeps the image's own file name (so .jpg here);
    # JPEG is lossy and will smear the 0/255 edges - consider saving as .png.
    cv2.imwrite(os.path.join(save_dir, os.path.basename(image_file)), binary_mask)