## 源文件(未改)
# Create the output directory (no error if it already exists).
os.makedirs(output_dir, exist_ok=True)

# Load the input image: PIL image for visualization/sizing, tensor for the model.
image_pil, image = load_image(image_path)

# Load the GroundingDINO model from its config and checkpoint.
model = load_model(config_file, grounded_checkpoint, device=device)

# Save the raw input image for reference.
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

# Run GroundingDINO: boxes come back as normalized (cx, cy, w, h) plus phrases.
boxes_filt, pred_phrases = get_grounding_output(
    model, image, text_prompt, box_threshold, text_threshold, device=device
)

# Initialize SAM and feed it the RGB image.
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; SAM expects RGB
predictor.set_image(image)

# Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixel
# coordinates in one vectorized step (previously a per-box Python loop).
W, H = image_pil.size
boxes_filt = boxes_filt * torch.tensor([W, H, W, H], dtype=boxes_filt.dtype)
boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2   # center -> top-left corner
boxes_filt[:, 2:] += boxes_filt[:, :2]       # size -> bottom-right corner
boxes_filt = boxes_filt.cpu()

# Map the boxes into SAM's input frame and predict one mask per box.
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,  # already on `device`; redundant second .to(device) removed
    multimask_output=False,
)

# Draw masks and labeled boxes over the image and save the visualization.
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box, label in zip(boxes_filt, pred_phrases):
    show_box(box.numpy(), plt.gca(), label)
plt.axis('off')
plt.savefig(
    os.path.join(output_dir, "grounded_sam_output.jpg"),
    bbox_inches="tight", dpi=300, pad_inches=0.0
)
plt.close()  # release the figure explicitly

# Persist mask data, boxes and phrases.
save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
## 修改后的代码
# Create the output directory (no error if it already exists).
os.makedirs(output_dir, exist_ok=True)

# Load GroundingDINO once, up front.
model = load_model(config_file, grounded_checkpoint, device=device)

# Build the SAM predictor once as well. It was previously rebuilt from its
# checkpoint inside the per-image loop, reloading the weights for every file.
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

# Collect the input images (.jpg / .png) from the input directory.
input_files = [f for f in os.listdir(input_dir) if f.endswith(('.jpg', '.png'))]

for idx, input_file in enumerate(input_files):
    print(f'Processing file {idx + 1} of {len(input_files)}: {input_file}')
    input_path = os.path.join(input_dir, input_file)

    # Load the image: PIL image for sizing, tensor for GroundingDINO.
    image_pil, image = load_image(input_path)

    # Detect boxes (normalized cx, cy, w, h) and their matching phrases.
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, device=device
    )

    # Feed the RGB version of the image to SAM.
    image = cv2.imread(input_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; SAM expects RGB
    predictor.set_image(image)

    # Convert normalized (cx, cy, w, h) to absolute (x0, y0, x1, y1) pixel
    # coordinates, vectorized instead of a per-box Python loop.
    W, H = image_pil.size
    boxes_filt = boxes_filt * torch.tensor([W, H, W, H], dtype=boxes_filt.dtype)
    boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2   # center -> top-left corner
    boxes_filt[:, 2:] += boxes_filt[:, :2]       # size -> bottom-right corner
    boxes_filt = boxes_filt.cpu()

    # Predict masks in chunks of `batch_size` boxes to bound GPU memory.
    # (The original also accumulated transformed_boxes_list and concatenated
    # it, but the concatenated result was never used — dead code removed.)
    masks_list = []
    for start in range(0, len(boxes_filt), batch_size):
        batch_boxes = boxes_filt[start:start + batch_size]
        transformed_boxes = predictor.transform.apply_boxes_torch(
            batch_boxes, image.shape[:2]
        ).to(device)
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        masks_list.append(masks)

    # Guard against zero detections: torch.cat on an empty list raises.
    if not masks_list:
        print(f'No boxes detected in {input_file}; skipping.')
        continue
    masks = torch.cat(masks_list, dim=0)

    # Draw masks and labeled boxes over the image and save the visualization.
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box, label in zip(boxes_filt, pred_phrases):
        show_box(box.numpy(), plt.gca(), label)
    plt.axis('off')
    plt.savefig(
        os.path.join(output_dir, f"grounded_sam_output_{input_file}"),
        bbox_inches="tight", dpi=300, pad_inches=0.0
    )
    # Essential inside a loop: without close(), every iteration leaks an open
    # figure and matplotlib warns/slows after ~20 images.
    plt.close()

    # Persist mask data, boxes and phrases under the image's base name.
    save_mask_data(output_dir, masks, boxes_filt, pred_phrases,
                   os.path.splitext(input_file)[0])
## 代码解释

- `os.makedirs(output_dir, exist_ok=True)`:创建一个目录 `output_dir`,如果该目录已经存在则忽略。
- `model = load_model(config_file, grounded_checkpoint, device=device)`:从指定的配置文件 `config_file` 和检查点文件 `grounded_checkpoint` 中加载模型,并指定计算设备 `device`。
- `input_files = [f for f in os.listdir(input_dir) if f.endswith('.jpg') or f.endswith('.png')]`:列出输入目录 `input_dir` 中所有以 `.jpg` 或 `.png` 结尾的文件,并存储在 `input_files` 列表中。
- `for idx, input_file in enumerate(input_files):`:遍历 `input_files` 列表中的每个输入文件。
- `image_pil, image = load_image(os.path.join(input_dir, input_file))`:使用 `load_image` 函数从指定路径读取图像文件,并返回一个 PIL 图像对象 `image_pil` 和一个 NumPy 数组 `image`。
- `boxes_filt, pred_phrases = get_grounding_output(...)`:使用 `get_grounding_output` 函数对图像 `image` 进行模型预测,得到边界框 `boxes_filt` 和预测短语 `pred_phrases`。
- `predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))`:根据指定的 SAM 模型检查点文件 `sam_checkpoint` 构建一个 SAM 预测器 `predictor`。
- `image = cv2.imread(os.path.join(input_dir, input_file))`:使用 OpenCV 库的 `imread` 函数读取图像文件,并将其存储在变量 `image` 中。
- `image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)`:将图像的颜色空间从 BGR 转换为 RGB 格式。
- 对边界框进行相应的变换和处理操作。
- `plt.figure(figsize=(10, 10))`:创建一个大小为 10x10 的新图像窗口。
- `plt.imshow(image)`:在图像窗口中显示图像。
- `for mask in masks:`:遍历 `masks` 列表中的每个掩码。
- `show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)`:使用 `show_mask` 函数在图像窗口中显示掩码。
- `for box, label in zip(boxes_filt, pred_phrases):`:使用 `zip` 函数将 `boxes_filt` 和 `pred_phrases` 进行配对。
- `show_box(box.numpy(), plt.gca(), label)`:使用 `show_box` 函数在图像窗口中显示边界框和相应的标签。
- `plt.axis('off')`:关闭图像窗口的坐标轴显示。
- `plt.savefig(os.path.join(output_dir, f"grounded_sam_output_{input_file}"), bbox_inches="tight", dpi=300, pad_inches=0.0)`:将图像窗口以指定的文件名保存到输出目录 `output_dir` 中。
- `save_mask_data(output_dir, masks, boxes_filt, pred_phrases, os.path.splitext(input_file)[0])`:使用 `save_mask_data` 函数保存掩码数据、边界框和预测短语数据到输出目录 `output_dir` 中。

代码功能是对输入目录中的图像文件进行模型预测和图像处理,并将处理后的结果批量保存到输出目录中。