SAM(Segment Anything Model)大模型使用--point prompt

本文介绍了如何在代码中运用SAM(Segment Anything Model)进行点提示(point prompt)的图像分割任务。SAM基于视觉Transformer架构,旨在建立高性能图像分割模型。通过hugging face库,可以更方便地使用预训练的SAM模型。文章展示了使用预训练模型进行预测并可视化结果的过程,为后续的框提示使用奠定了基础。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

概述

本系列将做一个专题,主要关于介绍如何在代码上运行并使用SAM模型以及如何用自己的数据集微调SAM模型,也是本人的毕设内容,这是一个持续更新系列,欢迎大家关注~


SAM(Segment Anything Model)

SAM基于visual transformer架构,希望通过这个大一统的框架建立起在图像分割领域的高性能模型,由于在图像分割领域可用的训练数据较为缺乏,团队在数据的获取训练这一块任务也专门设计的对应的流程。在SAM的论文中,主要把SAM模型的构建分成了三个部分,分别是任务、模型和数据。


模型使用

有一个Meta发布了一个网站demo,感兴趣的朋友可以根据网站的指导进行使用,对SAM的效果有一个直观的感受Segment Anything | Meta AI (segment-anything.com)

本篇博客主要介绍使用hugging face中封装好的函数对SAM进行point prompt的分割任务,对比github上SAM的源码使用,hugging face的函数更加方便使用

依赖的环境库

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor

从transfomers库上下载预训练好的SAM模型,一般保存在C盘下用户的.cache文件下

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

这里的sam-vit-base是比较小的权重,有300多M,还有facebook/sam-vit-huge有1G多,这里使用前者进行使用的演示

定义可视化的函数

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points_on_image(raw_image, input_points, input_labels=None):
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
      labels = np.ones_like(input_points[:, 0])
    else:
      labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    plt.axis('on')
    plt.show()

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_masks_on_image(raw_image, masks, scores):
    if len(masks.shape) == 4:
      masks = masks.squeeze()
    if scores.shape[0] == 1:
      scores = scores.squeeze()

    nb_predictions = scores.shape[-1]
    fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))

    for i, (mask, score) in enumerate(zip(masks, scores)):
      mask = mask.cpu().detach()
      axes[i].imshow(np.array(raw_image))
      show_mask(mask, axes[i])
      axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
      axes[i].axis("off")

我们使用下面这张图像来进行演示

可视化一个我们的图像和point prompt的位置

raw_image = Image.open(r'D:\CSDN_point\3_11_model\yunnan.jpg')

input_points = [[[900, 1050]]]
show_points_on_image(raw_image, input_points[0])

进行mask的预测,没有进行参数限制的话输出的图片为3张,并且对应有iou的预测得分

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)


with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores

show_masks_on_image(raw_image, masks[0], scores)

 总结

下面就是简单地使用点提示在代码上进行SAM模型的使用,在微调的任务中,普遍认为框提示的效果会比点提示好,所以下一节我们来介绍一下使用框提示的内容

欢迎大家讨论交流~

### Grounding DINO 与 SAM 集成方法概述 在计算机视觉任务中,Grounding DINO 和 Segment Anything Model (SAM) 是两种强大的工具。Grounding DINO 可以通过自然语言描述定位图像中的目标区域[^2],而 SAM 则能够高效地生成高质量的分割掩码[^1]。两者的结合可以显著提升涉及语义理解的任务性能。 #### 方法说明 一种常见的集成方式是先使用 Grounding DINO 提取感兴趣的目标边界框或关键点位置,随后将这些作为输入传递给 SAM 进行精细化分割。具体流程如下: - **文本到图像映射**:利用 Grounding DINO 的能力,接收一段文字描述并返回对应的检测框坐标以及置信度分数。 - **分割细化处理**:把上述得到的结果送入预训练好的 SAM 模型实例化对象,在此基础上完成像素级精确划分操作。 以下是 Python 中如何调用这两个库的一个简单例子: ```python from groundingdino.util.inference import load_model, predict import torch from segment_anything import sam_model_registry, SamPredictor # 加载模型 grounding_dino = load_model("path/to/GroundingDINO/weight.pth") sam_checkpoint = "path/to/sam_vit_h_4b8939.pth" model_type = "vit_h" device = "cuda" if torch.cuda.is_available() else "cpu" sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device) predictor = SamPredictor(sam) image_path = 'example.jpg' text_prompt = "a photo of a dog." boxes, logits, phrases = predict( model=grounding_dino, image=image_path, caption=text_prompt, box_threshold=0.35, text_threshold=0.25 ) for i in range(len(boxes)): x_min, y_min, x_max, y_max = boxes[i].numpy() predictor.set_image(image) input_box = np.array([x_min, y_min, x_max, y_max]) masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=input_box[None,:], multimask_output=False) ``` 此脚本展示了从加载必要的权重文件开始直到最终获得二值掩膜的过程。 #### 技术优势分析 这种组合不仅继承了各自单独使用的优点还进一步增强了整体系统的灵活性和鲁棒性。例如,在电商领域应用时,面对复杂背景下的产品图片,仅依靠传统边缘检测算法可能难以达到理想效果;但如果引入该方案,则可以通过简单的指令快速准确地标记出主体轮廓以便后续编辑加工。 另外值得注意的是,尽管当前版本已经表现出色但仍存在改进空间——比如当遇到多个相似类别物体共存于同一画面内时可能会出现混淆现象等问题待解决。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值