First, I used segment-anything (the `segment-anything-main` code for Segment Anything, SAM) to generate a binary image called “mask”. Then I multiplied the original image by the “mask” image to get the target (segmented) image.
Fig.1 Original image
Fig.2 mask image
To improve the accuracy of the subsequent processing, I applied the Canny operator to the mask image to detect its edges. I then used the cv2.boundingRect() function to obtain the parameters (x, y, w, h) of the smallest upright rectangle enclosing the detected edges, and used these parameters to crop the segmented image. The result is the cropped image.
Fig.3 Segmented image
Fig.4 Cropped image
My code is shown below:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
# --- Load the SAM model and the input image ---------------------------------
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing inside sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
print('0')  # debug marker: reached once the predictor has been built

image = cv2.imread('truck.jpg')
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None, not by
    # raising; fail here with a clear message instead of inside cvtColor.
    raise FileNotFoundError("could not read image 'truck.jpg'")
# OpenCV loads BGR; SAM and matplotlib work with RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w, c = image.shape
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a colored RGBA image.

    Args:
        mask: array with H*W elements whose last two dimensions are (H, W),
            e.g. (H, W) or (1, H, W). Each pixel value scales the RGBA color,
            so a boolean mask colors True pixels and leaves False pixels
            fully transparent.
        ax: axes object whose ``imshow`` receives the (H, W, 4) overlay.
        random_color: when True use a random RGB color, otherwise a fixed
            blue; the alpha channel is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    rows, cols = mask.shape[-2:]
    # Broadcast (rows, cols, 1) against (1, 1, 4) -> (rows, cols, 4).
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter SAM prompt points on ``ax`` as white-edged star markers.

    Points whose label is 1 are drawn green, points whose label is 0 are
    drawn red (green first, then red); any other label is not drawn.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: length-N array of point labels, compared elementwise.
        ax: axes object providing ``scatter``.
        marker_size: marker area passed as ``s`` to ``scatter``.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
    """Draw ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle.

    Args:
        box: indexable with four corner coordinates; (x0, y0) is the anchor
            corner and (x1, y1) the opposite corner.
        ax: matplotlib axes that receives the patch via ``add_patch``.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    # facecolor (0, 0, 0, 0) has zero alpha, so only the 2-px outline shows.
    outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                            edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)
# --- Prompt SAM with one foreground point and pick a mask --------------------
predictor.set_image(image)

# A single positive prompt (label 1) at pixel (x=300, y=375).
input_point = np.array([[300, 375]])
input_label = np.array([1])

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Show the input image with the prompt point. show_points must run before
# plt.show(): calling it afterwards draws onto a new, empty figure.
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

print(f"mask shape:{masks.shape}")  # (number_of_masks) x H x W

# Use the third candidate mask. NOTE(review): `scores` is not consulted here;
# np.argmax(scores) would select the highest-scoring candidate instead.
# Assumes the masks are boolean, so m holds 0/1 values.
m = np.uint8(masks[2])
# Pure black/white (0/255) mask, used for display and for locating the
# minimum bounding rectangle.
m2 = m * 255
plt.imshow(m2, cmap="gray")
plt.show()

# --- Apply the mask to the image -------------------------------------------
# Keep only pixels covered by the mask: broadcast the (H, W, 1) mask over the
# colour channels (uint8 * uint8 stays uint8, values are 0 or the pixel).
masked_img = image * m[:, :, np.newaxis]
plt.imshow(masked_img)
plt.show()

# `image` was converted to RGB after cv2.imread, but cv2.imwrite expects BGR;
# convert back before saving, otherwise red and blue are swapped on disk.
masked_img_bgr = cv2.cvtColor(masked_img, cv2.COLOR_RGB2BGR)
cv2.imwrite('diff.jpg', masked_img_bgr)

# --- Crop to the minimum bounding rectangle of the mask ---------------------
# Computed directly from the nonzero mask pixels: this encloses every
# disconnected piece of the mask at once, and unlike Canny edges it does not
# miss mask boundaries that lie on the image border.
if not m2.any():
    raise RuntimeError("selected mask is empty; nothing to crop")
box_x, box_y, box_w, box_h = cv2.boundingRect(cv2.findNonZero(m2))

# Distinct names keep the image dimensions h, w from being overwritten.
cropped_image = masked_img_bgr[box_y:box_y + box_h, box_x:box_x + box_w]
cv2.imwrite('1.jpg', cropped_image)