计算机视觉 | YOLO 和 SAM 强强联合能干什么大事

点击下方卡片,关注“小白玩转Python”公众号

在这篇博客中,我们将探索计算机视觉和图像分析的迷人领域,探讨两种开创性模型之间的动态协同:YOLO(You Only Look Once)和 SAM(Segment Anything Model)。YOLO 因其在目标检测方面的革命性进展而备受赞誉,与在分割领域具有强大实力的 SAM 相结合,承诺带来令人兴奋的能力融合。

那么,什么是 SAM(Segment Anything Model)?:

SAM 由 Meta 在 2023 年推出,是一种革命性的图像分割模型。以其卓越的性能而著称,SAM 已巩固了其作为最先进的分割模型之一的地位。SAM 在图像分割技术方面代表了一项突破性的进展,提供了前所未有的精确性和多功能性。与传统的受特定对象类型或环境限制的分割模型不同,SAM 凭借先进的神经网络架构和在大型数据集上进行的广泛训练,能够以无与伦比的准确性分割图像中的几乎任何对象。

架构:

如论文中的图表所示,SAM 的架构采用多阶段的图像分割方法。其核心是一系列互联的神经网络模块,每个模块都针对分割过程的不同方面进行处理。

f058ea5fc4c78520ec5867305589a25a.png

架构的初始阶段涉及特征提取,输入图像通过卷积层处理以提取相关特征。这些特征随后通过一系列编码和解码层传递,在提取高层语义信息的同时保留空间细节。SAM 的关键创新在于其注意力机制,使模型在分割过程中能够有选择地关注图像中的相关区域。这种注意力机制通过一组注意力模块实现,基于上下文线索和特征重要性动态调整模型的关注点。

此外,SAM 还引入了跳跃连接,以促进不同网络层之间的信息流动。这些连接使模型能够利用低层和高层特征,增强其捕捉复杂细节和上下文的能力。总体而言,SAM 的架构经过精心设计,优化了分割过程,利用注意力机制和跳跃连接等先进的神经网络技术,实现了精确且多功能的分割结果。这种复杂的设计使 SAM 在广泛的应用中表现出色,从医学成像到自动驾驶,确立了其在计算机视觉领域的开创性地位。

SAM 的关键特性:

  • 多功能性:SAM 设计用于分割图像中的任何事物,从日常物品到复杂场景,具备出色的准确性和细节。

  • 鲁棒性:得益于其复杂的架构和在多样化数据集上的广泛训练,SAM 在各种场景中表现出色,包括不同的光照条件和物体方向。

  • 规模:SAM 能够处理不同分辨率和规模的图像,适用于高分辨率图像和实时应用。

  • 上下文理解:SAM 整合了上下文信息,以提高分割精度,甚至在杂乱场景中也能有效地区分对象与其环境。

  • 效率:尽管具备先进功能,SAM 仍保持高效率,确保快速处理速度,非常适合实时应用。

  • 适应性:SAM 可以为特定任务或数据集进行微调和定制,允许无缝集成到各种应用和行业中。

论文:https://arxiv.org/pdf/2304.02643.pdf

现在,让我们深入探讨如何将 YOLO 与 SAM 嵌入在一起。但是,为什么我们需要将这两个模型结合在一起?

将 YOLO(You Only Look Once)与 SAM(Segment Anything Model)结合起来,提供了一个强大的协同效应,增强了两个模型的能力。 YOLO 在快速识别图像中的对象方面表现出色,而 SAM 在高精度分割对象方面具有优势。通过将 YOLO 与 SAM 嵌入在一起,我们可以利用这两个模型的优势,实现更全面和准确的图像分析。这种集成不仅可以检测对象,还可以精确地描绘它们的边界,为下游任务提供更丰富的上下文信息。此外,将 YOLO 与 SAM 嵌入在一起,可以更稳健和高效地处理复杂的视觉数据,在自动驾驶、医学成像和监控系统等应用中具有不可估量的价值。

实施 SAM 处理图像:

步骤1:首先从 GitHub 仓库下载 SAM 模型。

import os
HOME = os.getcwd()
pip install roboflow ultralytics 'git+https://github.com/facebookresearch/segment-anything.git'

步骤2:安装 SAM 模型的权重,可以从 SAM 的 GitHub 仓库获取。

%cd {HOME}/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

步骤3:验证是否已成功下载 SAM 权重文件。

import os


CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

步骤4:加载模型。

import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor


DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"


sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

步骤5:初始化掩码生成器。

mask_generator = SamAutomaticMaskGenerator(sam)

步骤6:为图像生成掩码。sam_result 变量包含生成的掩码。

import cv2
import supervision as sv


image_bgr = cv2.imread("path/to/image")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


sam_result = mask_generator.generate(image_rgb)

掩码生成返回一个包含多个掩码的列表,每个掩码都是一个包含各种数据的字典。这些键包括:

  • Segmentation:掩码

  • area:掩码的像素面积

  • bbox:掩码的边界框,格式为 XYWH

  • predicted_iou:模型对掩码质量的自我预测

  • point_coords:生成该掩码的采样输入点

  • stability_score:掩码质量的附加度量

  • crop_box:用于生成该掩码的图像裁剪框,格式为 XYWH

print(len(masks))
print(sam_results[0].keys())

步骤7:结果

mask_annotator = sv.MaskAnnotator(color_lookup = sv.ColorLookup.INDEX)


detections = sv.Detections.from_sam(sam_result=sam_result)


annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)


sv.plot_images_grid(
    images=[image_bgr, annotated_image],
    grid_size=(1, 2),
    titles=['source image', 'segmented image']
)

c6f5cf0211f9981502c4225d14f016a7.png

现在,实施 YOLO+SAM 处理视频:

步骤1:下载 YOLO、SAM 权重和其他依赖项。

from ultralytics import YOLO


from IPython.display import display, Image


model = YOLO(MODEL)
model.fuse()
%cd {HOME}


import sys
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!mkdir {HOME}/weights


%cd {HOME}/weights


!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

步骤2:确保正确安装了权重,并用掩码生成器初始化 SAM。

import os


CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))
import torch
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor


DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry["vit_h"](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
mask_predictor = SamPredictor(sam)

步骤3:嵌入 YOLO 检测和 SAM 掩码。

CLASS_NAMES_DICT = model.model.names


# class_ids of interest - based on the number of classses
CLASS_ID = [item for item in range(0,len(CLASS_NAMES_DICT))]


CLASS_NAMES_DICT
import cv2
import numpy as np
import torch


# Replace the following line with your actual VIDEO_PATH
VIDEO_PATH = "/path/to/input_video"
OUTPUT_VIDEO_PATH = "/path/to/save/output_video"


# This will contain the resulting mask predictions for local use
mask_frames = []


def get_video_dimensions(cap):
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    return width, height


def add_color_to_mask(mask, color):
    # Convert the color tensor to CPU
    color = torch.tensor(color).cpu().numpy()


    # Create a binary mask based on the original mask
    color_mask = np.zeros_like(mask.cpu().numpy(), dtype=np.uint8)
    color_mask[mask.cpu().numpy() > 0] = 1  # Set non-zero values to 1


    # Expand the color tensor and apply it to the binary mask
    colored_mask = color_mask[..., None] * color


    return colored_mask


def draw_class_names(frame, class_names, positions, color, font_size=0.5):
    for class_name, position in zip(class_names, positions):
        cv2.putText(frame, class_name, position, cv2.FONT_HERSHEY_SIMPLEX, font_size, color, 2, cv2.LINE_AA)


def draw_yolov8_boxes(frame, boxes, color):
    for box in boxes:
        box = list(map(int, box))
        cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), color, 2)


constant_mask_color = np.array([0, 0, 255], dtype=np.uint8)  # Red color for masks
output_class_color = (0, 255, 0)  # Green color for class names
yolov8_box_color = (255, 0, 0)  # Blue color for YOLOv8 bounding boxes


cap = cv2.VideoCapture(VIDEO_PATH)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))


fourcc = cv2.VideoWriter_fourcc(*'XVID')
output_video = cv2.VideoWriter(OUTPUT_VIDEO_PATH, fourcc, 15.0, (width, height))


frame_num = 1
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break


    # Check if the frame is empty or None
    if frame is None:
        continue  # Skip processing for empty frames


    # Run frame through YOLOv8 to get detections
    detections = model.predict(frame, conf=0.7)


    # Check if there are fish detections
    if len(detections[0].boxes) == 0:
        continue  # Skip processing for frames without fish detections


    # Run frame and detections through SAM to get masks
    transformed_boxes = mask_predictor.transform.apply_boxes_torch(
        detections[0].boxes.xyxy, list(get_video_dimensions(cap))
    )
    mask_predictor.set_image(frame)
    masks, _, _ = mask_predictor.predict_torch(
        boxes=transformed_boxes,
        multimask_output=False,
        point_coords=None,
        point_labels=None
    )


    # Check if the mask is empty
    if masks[0][0].numel() == 0:
        continue  # Skip processing for empty masks


    # Combine mask predictions into a single mask, each with the same color
    class_ids = detections[0].boxes.cpu().cls
    merged_with_colors = add_color_to_mask(masks[0][0], constant_mask_color)
    for i in range(1, len(masks)):
        curr_mask_with_colors = add_color_to_mask(masks[i][0], constant_mask_color)
        merged_with_colors = np.bitwise_or(merged_with_colors, curr_mask_with_colors)


    # Draw YOLOv8 bounding boxes on the frame
    draw_yolov8_boxes(frame, detections[0].boxes.xyxy, yolov8_box_color)


    # Draw class names on the frame with a slightly larger font
    class_names = [CLASS_NAMES_DICT[int(class_id)] for class_id in class_ids]
    draw_class_names(frame, class_names, [(int(box[0]), int(box[1])) for box in detections[0].boxes.xyxy], output_class_color, font_size=0.7)


    # Overlay the SAM masks onto the frame
    frame_with_masks = cv2.addWeighted(frame, 1, merged_with_colors, 0.5, 0)


    # Write the frame with masks, YOLOv8 boxes, and class names to the output video
    output_video.write(frame_with_masks)


    frame_num += 1


cap.release()
output_video.release()
cv2.destroyAllWindows()

701ddf90b023dffdad2678eec3dad6cd.png

上述示例适用于标准 YOLO 训练的数据集。此外,此方法也可用于自定义数据集。

应用:

  • 自动驾驶车辆:增强目标检测和分割能力,确保自动驾驶汽车的安全导航和决策。

  • 医学成像:通过精确识别和分割医学图像中的异常,提高诊断准确性,如 X 光片、MRI 和 CT 扫描。

  • 监控系统:通过精确检测和分割感兴趣的对象,提高公共场所的安全监控。

  • 工业自动化:通过检测和分割装配线上制造产品中的缺陷,优化质量控制过程。

  • 农业:通过精确识别和分割农业图像中的植物和害虫,协助作物监测和害虫检测。

  • 环境监测:通过检测和分割卫星图像中的树木、水体和野生动物,帮助监测和分析环境变化。

  • 增强现实:通过精确检测和分割现实世界中的物体,提升 AR 应用的沉浸式用户体验。

  • 零售分析:通过精确检测和分割零售环境中的产品,改善客户分析和库存管理。

总之,SAM(Segment Anything Model)和 YOLO(You Only Look Once)的融合代表了图像分析领域的重大进步,在各个领域具有深远的影响。这一整合结合了 YOLO 在目标检测方面的敏锐性和 SAM 在分割方面的精确性,使我们能够从视觉数据中获得更深入的见解。从优化自动驾驶车辆的感知系统到帮助医学专家诊断疾病,SAM+YOLO 的协同潜力远远超越了传统边界。

·  END  ·

HAPPY LIFE

60c50eb004d8dce632bf55a08e9cf924.png

本文仅供学习交流使用,如有侵权请联系作者删除

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
需要学习ubuntu系统上YOLOv4的同学请前往:《YOLOv4目标检测实战:原理与源码解析》 【为什么要学习这门课】 Linux创始人Linus Torvalds有一句名言:Talk is cheap. Show me the code. 冗谈不够,放码过来! 代码阅读是从基础到提高的必由之路。尤其对深度学习,许多框架隐藏了神经网络底层的实现,只能在上层调包使用,对其内部原理很难认识清晰,不利于进一步优化和创新。YOLOv4是最近推出的基于深度学习的端到端实时目标检测方法。YOLOv4的实现darknet是使用C语言开发的轻型开源深度学习框架,依赖少,可移植性好,可以作为很好的代码阅读案例,让我们深入探究其实现原理。【课程内容与收获】 本课程将解析YOLOv4的实现原理和源码,具体内容包括:- YOLOv4目标检测原理- 神经网络及darknet的C语言实现,尤其是反向传播的梯度求解和误差计算- 代码阅读工具及方法- 深度学习计算的利器:BLAS和GEMM- GPU的CUDA编程方法及在darknet的应用- YOLOv4的程序流程- YOLOv4各层及关键技术的源码解析本课程将提供注释后的darknet的源码程序文件。【相关课程】 除本课程《Windows版YOLOv4目标检测:原理与源码解析》外,本人推出了有关YOLOv4目标检测的系列课程,包括:《Windows版YOLOv4目标检测实战:训练自己的数据集》《Windows版YOLOv4-Tiny目标检测实战:训练自己的数据集》《Windows版YOLOv4目标检测实战:人脸口罩佩戴检测》《Windows版YOLOv4目标检测实战:中国交通标志识别》建议先学习一门YOLOv4实战课程,对YOLOv4的使用方法了解以后再学习本课程。【YOLOv4网络模型架构图】 下图由白勇老师绘制  

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值