My code is shown below:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("../..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# SAM checkpoint / architecture to load.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Use the GPU when one is available; otherwise fall back to CPU instead of
# crashing inside sam.to(device=...).
device = "cuda" if torch.cuda.is_available() else "cpu"

# cv2.imread does not raise on a missing/unreadable file, it returns None;
# fail here with a clear message rather than with an opaque cvtColor error.
image = cv2.imread('image.jpg')
if image is None:
    raise FileNotFoundError("could not read image file 'image.jpg'")
# OpenCV loads images in BGR order; convert to RGB (the order plt.imshow
# assumes, and the order this script feeds to SAM).
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w, c = image.shape

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# First pass: automatic mask generation with the generator's default settings.
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
def _masked_crop(image, mask):
    """Return *image* cropped to the bounding box of *mask*, black outside the mask.

    Args:
        image: array whose first two dimensions match *mask* (here an
            H x W x 3 RGB image).
        mask: H x W array; non-zero / True pixels belong to the object.

    Returns:
        A new array covering the tight bounding box of the mask, with every
        pixel not in the mask set to 0, or None if the mask is empty.
    """
    mask = np.asarray(mask, dtype=bool)
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    if rows.size == 0:
        return None
    # Half-open bounds of the smallest rectangle containing every mask pixel.
    y0, y1 = rows[0], rows[-1] + 1
    x0, x1 = cols[0], cols[-1] + 1
    crop = image[y0:y1, x0:x1].copy()
    # Black out pixels that fall inside the box but are not part of the object.
    crop[~mask[y0:y1, x0:x1]] = 0
    return crop


def show_anns(image, anns):
    """Save a masked, tightly cropped image for every SAM annotation.

    Annotations are processed in descending ``'area'`` order. For each one,
    pixels of *image* outside its ``'segmentation'`` mask are blacked out, the
    result is cropped to the mask's bounding box (covering all of the mask's
    pixels, even if it has several disconnected parts), and it is written to
    ``cropped<j>.jpg``, where ``j`` is the 1-based rank in that order.
    Empty masks produce no file.

    Args:
        image: 3-channel uint8 image in RGB order. Assumed from the module-level
            caller, which converts with ``cv2.COLOR_BGR2RGB`` before calling.
        anns: list of annotation dicts (as returned by
            ``SamAutomaticMaskGenerator.generate`` above). Each must provide
            ``'area'`` and a ``'segmentation'`` mask with the same height and
            width as *image*.

    Returns:
        None. An empty *anns* is a no-op.
    """
    # Guard first: the debug prints below index anns[0].
    if not anns:
        return
    print(f"anns segmentation:{anns[0]['segmentation']}")
    print(f"anns segmentation shape:{anns[0]['segmentation'].shape}")

    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    # Left over from the overlay-drawing version of this function; nothing
    # is drawn on this axis any more.
    ax = plt.gca()
    ax.set_autoscale_on(False)

    for j, ann in enumerate(sorted_anns, start=1):
        mask = ann['segmentation']
        print(mask.shape)
        crop = _masked_crop(image, mask)
        if crop is None:
            continue  # empty mask: nothing to save
        # image is RGB, but cv2.imwrite expects BGR channel order.
        cv2.imwrite(f'cropped{j}.jpg', cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
# Report what the default-settings generator produced.
print(len(masks))
if masks:
    # Guarded: indexing masks[0] would raise IndexError on an empty result.
    print(masks[0].keys())

# Second pass with custom sampling / filtering settings.
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=4,
    pred_iou_thresh=0.92,
    stability_score_thresh=0.97,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=1000,  # Requires open-cv to run post-processing
)
masks2 = mask_generator_2.generate(image)
# A bare `len(masks2)` only echoes in a notebook cell; print it in a script.
print(len(masks2))

plt.figure(figsize=(20, 20))
plt.imshow(image)
# Writes one cropped<j>.jpg per non-empty mask in masks2.
show_anns(image, masks2)
plt.axis('off')
plt.show()