SAM2视频模块使用(翻译自video_predictor_example.ipynb)

使用 SAM 2 进行视频分割
本笔记本介绍如何使用 SAM 2 进行视频交互式分割。它将涵盖以下内容:

在帧上添加点击,以获取并完善小掩码(时空掩码)
在整个视频中传播点击以获取掩码
同时分割和跟踪多个对象
我们使用分段或掩码来指单个帧上的物体模型预测,使用小掩码来指整个视频中的时空掩码。

如果使用 jupyter 在本地运行,请首先使用软件仓库中的安装说明在您的环境中安装 segment-anything-2。

导入库

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# 为整个 notebook 使用 bfloat16
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# 如果 CUDA 设备的属性为 8 或更高版本,则为 Ampere GPU 开启 tfloat32 
# 详细信息参考 https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

加载 SAM 2 视频预测器

from sam2.build_sam import build_sam2_video_predictor

# 指定 sam2 模型的检查点文件路径
sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"

# 指定模型配置文件路径
model_cfg = "sam2_hiera_l.yaml"

# 使用指定的模型配置和检查点文件构建视频预测器
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

 sam2_checkpoint修改为自己的权重文件所在位置

model_cfg修改为与你权重相匹配的,一般在sam2_configs文件夹

选择视频示例
我们假设视频存储为 JPEG 帧列表,文件名为 <frame_index>.jpg。

对于自定义视频,您可以使用 ffmpeg (https://ffmpeg.org/) 提取其 JPEG 帧,如下所示:

ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg' 其中,-q:v 生成高质量的 JPEG 帧。
其中 -q:v 生成高质量的 JPEG 帧,而 -start_number 0 则要求 ffmpeg 从 00000.jpg 开始生成 JPEG 文件。

# `video_dir` 是一个包含 JPEG 帧的目录,文件名格式如 `<frame_index>.jpg`
video_dir = "./videos/bedroom"

# 扫描该目录中的所有 JPEG 帧文件名
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# 查看第一帧视频帧
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

注意:记得加上plt.show()否则可能没有输出

(图见SAM2中video_predictor_example.ipynb,懒的放了)

初始化推理状态
SAM 2 需要有状态的推理来进行交互式视频分割,因此我们需要在这段视频上初始化推理状态。

在初始化过程中,它会加载 video_path 中的所有 JPEG 帧,并将其像素存储在 inference_state 中(如下图进度条所示)。

# 初始化推理状态
inference_state = predictor.init_state(video_path=video_dir)

例 1:分割并跟踪一个对象
注意:如果您之前使用此 inference_state 运行过任何跟踪,请先通过 reset_state 重置它。

(下面的单元格只是为了说明;这里不需要调用 reset_state,因为这个 inference_state 只是刚刚初始化)。

# 重置推理状态
predictor.reset_state(inference_state)

步骤 1:在框架上添加第一次点击
首先,让我们尝试分割左侧的孩子。

在这里,我们通过向 add_new_points API 发送坐标和标签,在 (x, y) = (210, 350) 处添加标签为 1 的正点击。

注意:标签 1 表示正点击(添加一个区域),标签 0 表示负点击(删除一个区域)。

ann_frame_idx = 0  # 交互的帧索引
ann_obj_id = 1  # 给每个交互对象一个唯一的ID(可以是任何整数)

# 添加一个正点击 (x, y) = (210, 350) 来开始
points = np.array([[210, 350]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1], np.int32)

# 向预测器添加新点
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 在当前(交互)帧上显示结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
plt.show()

注意:记得加上plt.show()否则可能没有输出

步骤 2:增加第二次点击以完善预测
嗯,看来虽然我们想要分割左侧的孩子,但模型只预测了短裤的遮罩--这有可能发生,因为单次点击会对目标对象产生歧义。我们可以通过再次点击孩子的上衣来完善这一帧的遮罩。

在这里,我们在 (x, y) = (250, 220) 处进行第二次正面点击,标签为 1,以扩展遮罩。

注意:在调用 add_new_points 时,我们需要发送所有点击及其标签(即不仅仅是最后一次点击)。

ann_frame_idx = 0  # 交互的帧索引
ann_obj_id = 1  # 给每个交互对象一个唯一的ID(可以是任何整数)

# 添加第二个正点击 (x, y) = (250, 220) 以优化掩码
# 将所有点击(及其标签)发送到 `add_new_points`
points = np.array([[210, 350], [250, 220]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1, 1], np.int32)

# 向预测器添加新点
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 在当前(交互)帧上显示结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
plt.show()

点击第 2 次细化后,我们就能得到第 0 帧上整个儿童的分割蒙版。

第 3 步:传播提示,在整个视频中获取小掩码
为了在整个视频中获取掩码,我们使用 propagate_in_video API 传播提示信息。

# 在整个视频中运行传播并将结果收集到一个字典中
video_segments = {}  # video_segments 包含每帧的分割结果
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# 每隔几帧渲染一次分割结果
vis_frame_stride = 15
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
    plt.show()

步骤 4:添加新的提示以进一步完善小掩码
在上面的输出小掩码中,第 150 帧的边界细节似乎存在一些瑕疵。

通过 SAM 2,我们可以交互式地修正模型预测。我们可以在该帧的 (x, y) = (82, 415) 处添加一个标签为 0 的负点击,以完善子掩码。在这里,我们使用不同的 frame_idx 参数调用 add_new_points 应用程序接口,以指示我们要细化的帧索引。

# 设定需要进一步细化的帧索引和对象ID
ann_frame_idx = 150  # 需要细化的帧索引
ann_obj_id = 1  # 与我们交互的对象的唯一ID(可以是任何整数)

# 显示细化前的分割结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx} -- before refinement")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)
plt.show()  # 确保图像在细化前显示

# 在该帧添加一个负点击 (x, y) = (82, 415) 以细化分割结果
points = np.array([[82, 415]], dtype=np.float32)
# 标签为`1`表示正点击,`0`表示负点击
labels = np.array([0], np.int32)
_, _, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 显示细化后的分割结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx} -- after refinement")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)
plt.show()  # 确保图像在细化后显示

第 5 步:传播提示(再次),在整个视频中获取小掩码
让我们更新整个视频的字幕。在此,我们再次调用 propagate_in_video,在添加上述新的细化点击后传播所有提示信息。

# 运行分割传播并在字典中收集结果
video_segments = {}  # video_segments 包含每帧的分割结果
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# 每隔几帧渲染分割结果
vis_frame_stride = 15
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
    plt.show()  # 确保每个渲染的图像显示出来

现在,所有框架上的线段都很美观。

例1完整代码如下:

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = r"E:\segment-anything-2-main\checkpoints\sam2_hiera_large.pt"
model_cfg = r"E:\segment-anything-2-main\sam2_configs\sam2_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = r"E:\segment-anything-2-main\notebooks\videos\bedroom"
#video_dir = r"E:\segment-anything-2-main\jc-imgs"
# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# take a look the first video frame
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
#plt.show()

inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

# Let's add a positive click at (x, y) = (210, 350) to get started
points = np.array([[210, 350]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
#plt.show()

ann_frame_idx = 0  # 交互的帧索引
ann_obj_id = 1  # 给每个交互对象一个唯一的ID(可以是任何整数)

# 添加第二个正点击 (x, y) = (250, 220) 以优化掩码
# 将所有点击(及其标签)发送到 `add_new_points`
points = np.array([[210, 350], [250, 220]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1, 1], np.int32)

# 向预测器添加新点
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 在当前(交互)帧上显示结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
plt.show()

# 在整个视频中运行传播并将结果收集到一个字典中
video_segments = {}  # video_segments 包含每帧的分割结果
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# 每隔几帧渲染一次分割结果
vis_frame_stride = 1
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
    plt.show()

# 设定需要进一步细化的帧索引和对象ID
ann_frame_idx = 15  # 需要细化的帧索引
ann_obj_id = 1  # 与我们交互的对象的唯一ID(可以是任何整数)

# 显示细化前的分割结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx} -- before refinement")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)
plt.show()  # 确保图像在细化前显示

# 在该帧添加一个负点击 (x, y) = (82, 415) 以细化分割结果
points = np.array([[82, 415]], dtype=np.float32)
# 标签为`1`表示正点击,`0`表示负点击
labels = np.array([0], np.int32)
_, _, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 显示细化后的分割结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx} -- after refinement")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)
plt.show()  # 确保图像在细化后显示


可以让gpt写一个分割结果整合成视频的代码,方便使用。

例2见http://t.csdnimg.cn/tL4cL

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值