Grounded-SAM 改批量处理图片并保存(附代码)

源文件(未改)

# make dir
    os.makedirs(output_dir, exist_ok=True)
    # load image
    image_pil, image = load_image(image_path)
    # load model
    model = load_model(config_file, grounded_checkpoint, device=device)

    # visualize raw image
    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # run grounding dino model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, device=device
    )

    # initialize SAM
    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    size = image_pil.size
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
###此处有一断点
    masks, _, _ = predictor.predict_torch(
        point_coords = None,
        point_labels = None,
        boxes = transformed_boxes.to(device),
        multimask_output = False,
    )
    
    # draw output image
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box, label in zip(boxes_filt, pred_phrases):
        show_box(box.numpy(), plt.gca(), label)

    plt.axis('off')
    plt.savefig(
        os.path.join(output_dir, "grounded_sam_output.jpg"), 
        bbox_inches="tight", dpi=300, pad_inches=0.0
    )
###有一断点
    save_mask_data(output_dir, masks, boxes_filt, pred_phrases)

修改后

# make dir
os.makedirs(output_dir, exist_ok=True)

# load model
model = load_model(config_file, grounded_checkpoint, device=device)
# iterate over input images
input_files = [f for f in os.listdir(input_dir) if f.endswith('.jpg') or f.endswith('.png')]
for idx, input_file in enumerate(input_files):
    print(f'Processing file {idx + 1} of {len(input_files)}: {input_file}')
    # load image
    image_pil, image = load_image(os.path.join(input_dir, input_file))
    # run grounding dino model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, device=device
    )
    # initialize SAM
    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
    image = cv2.imread(os.path.join(input_dir, input_file))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    size = image_pil.size
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()

    masks_list = []
    transformed_boxes_list = []
    for i in range(0, len(boxes_filt), batch_size):
        batch_boxes = boxes_filt[i:i + batch_size]
        transformed_boxes = predictor.transform.apply_boxes_torch(batch_boxes, image.shape[:2]).to(device)
        transformed_boxes_list.append(transformed_boxes)
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(device),
            multimask_output=False,
        )
        masks_list.append(masks)

    transformed_boxes = torch.cat(transformed_boxes_list, dim=0)
    masks = torch.cat(masks_list, dim=0)

    # draw output image
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box, label in zip(boxes_filt, pred_phrases):
        show_box(box.numpy(), plt.gca(), label)

    plt.axis('off')
    plt.savefig(
        os.path.join(output_dir, f"grounded_sam_output_{input_file}"),
        bbox_inches="tight", dpi=300, pad_inches=0.0
    )
    save_mask_data(output_dir, masks, boxes_filt, pred_phrases, os.path.splitext(input_file)[0])

代码解释

  1. os.makedirs(output_dir, exist_ok=True):创建一个目录 output_dir,如果该目录已经存在则忽略。

  2. model = load_model(config_file, grounded_checkpoint, device=device):从指定的配置文件 config_file 和检查点文件 grounded_checkpoint 中加载模型,并指定计算设备 device

  3. input_files = [f for f in os.listdir(input_dir) if f.endswith('.jpg') or f.endswith('.png')]:列出输入目录 input_dir 中所有以 .jpg 或 .png 结尾的文件,并存储在 input_files 列表中。

  4. for idx, input_file in enumerate(input_files)::遍历 input_files 列表中的每个输入文件。

  5. image_pil, image = load_image(os.path.join(input_dir, input_file)):使用 load_image 函数从指定路径读取图像文件,并返回一个PIL图像对象 image_pil 和一个NumPy数组 image

  6. boxes_filt, pred_phrases = get_grounding_output(...):使用 get_grounding_output 函数对图像 image 进行模型预测,得到边界框 boxes_filt 和预测短语 pred_phrases

  7. predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device)):根据指定的SAM模型检查点文件 sam_checkpoint 构建一个SAM预测器 predictor

  8. image = cv2.imread(os.path.join(input_dir, input_file)):使用OpenCV库的imread函数读取图像文件,并将其存储在变量 image 中。

  9. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB):将图像的颜色空间从BGR转换为RGB格式。

  10. 对边界框进行相应的变换和处理操作。

  11. plt.figure(figsize=(10, 10)):创建一个大小为10x10的新图像窗口。

  12. plt.imshow(image):在图像窗口中显示图像。

  13. for mask in masks::遍历masks列表中的每个掩码。

  14. show_mask(mask.cpu().numpy(), plt.gca(), random_color=True):使用show_mask函数在图像窗口中显示掩码。

  15. for box, label in zip(boxes_filt, pred_phrases)::使用zip函数将boxes_filtpred_phrases进行配对。

  16. show_box(box.numpy(), plt.gca(), label):使用show_box函数在图像窗口中显示边界框和相应的标签。

  17. plt.axis('off'):关闭图像窗口的坐标轴显示。

  18. plt.savefig(os.path.join(output_dir, f"grounded_sam_output_{input_file}"), bbox_inches="tight", dpi=300, pad_inches=0.0):将图像窗口以指定的文件名保存到输出目录 output_dir 中。

  19. save_mask_data(output_dir, masks, boxes_filt, pred_phrases, os.path.splitext(input_file)[0]):使用save_mask_data函数保存掩码数据、边界框和预测短语数据到输出目录 output_dir 中。

代码功能是对输入目录中的图像文件进行模型预测和图像处理,并将处理后的结果批量保存到输出目录中。

输出

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值