SAM2视频模块使用(翻译自video_predictor_example.ipynb)(二)

样例一已经在另一篇写过http://t.csdnimg.cn/ykThr

本篇直接从样例二开始

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = r"E:\segment-anything-2-main\checkpoints\sam2_hiera_large.pt"
model_cfg = r"E:\segment-anything-2-main\sam2_configs\sam2_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = r"E:\segment-anything-2-main\notebooks\videos\bedroom"
#video_dir = r"E:\segment-anything-2-main\jc-imgs"
# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# take a look the first video frame
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
plt.show()

例 2:同时分割多个对象
注意:如果您之前使用此 inference_state 运行过任何跟踪,请先通过 reset_state 重置它。

predictor.reset_state(inference_state)

步骤 1:在一个画面上添加两个对象
SAM 2 还可以同时分割和跟踪两个或多个对象。当然,一种方法是逐个进行。不过,更有效的方法是将它们放在一起批处理(例如,这样我们就可以共享对象之间的图像特征,从而降低计算成本)。

这次,我们将重点放在对象部分,并对视频中两个孩子的衬衣进行分割。在此,我们为这两个对象添加提示,并为每个对象分配一个唯一的对象 ID。

prompts = {}  # 用于保存我们添加的所有点击以便可视化

在第 0 帧的 (x, y) = (200, 300) 处以正点击添加第一个对象(左边孩子的衬衫)。

我们将其赋值给对象 id 2(可以是任意整数,只需每个要跟踪的对象都是唯一的),并将其传递给 add_new_points API,以区分我们正在点击的对象。

ann_frame_idx = 0  # 交互的帧索引
ann_obj_id = 2  # 给每个交互对象一个唯一的ID(可以是任何整数)

# 在第一个对象上添加一个正点击 (x, y) = (200, 300)
points = np.array([[200, 300]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1], np.int32)

# 将点击信息保存到 prompts 字典中
prompts[ann_obj_id] = points, labels

# 向预测器添加新点
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 在当前(交互)帧上显示结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
plt.show()

嗯,这次我们只想选择孩子的衬衫,但模型预测整个孩子都要戴面具。让我们在 (x, y) = (275, 175) 处点击一个负值来完善预测。

在第二次负点击后,我们得到的第一个对象是左边孩子的衬衫。

在此,我们为第二个对象分配对象 ID 3(可以是任意整数,只需每个要跟踪的对象都是唯一的)。

注意:当有多个对象时,add_new_points API 将返回每个对象的掩码列表。

ann_frame_idx = 0  # 交互的帧索引
ann_obj_id = 3  # 给每个交互对象一个唯一的ID(可以是任何整数)

# 现在我们开始处理第二个对象(给予对象ID `3`),并在 (x, y) = (400, 150) 处添加一个正点击
points = np.array([[400, 150]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels

# `add_new_points` 返回所有对象在当前交互帧上的掩码
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 显示当前帧上所有对象的结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
plt.show()

这次,只需点击一下,模型就能预测出我们要跟踪的衬衫的面罩。漂亮

第二步:传播提示,在整个视频中获取掩码
现在,我们在整个视频中传播两个对象的提示,以获取它们的掩码。

注意:当有多个对象时,propagate_in_video API 将返回每个对象的掩码列表。

# 在整个视频中运行传播并将结果收集到一个字典中
video_segments = {}  # video_segments 包含每帧的分割结果
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# 每隔几帧渲染一次分割结果
vis_frame_stride = 15
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
    plt.show()

设备问题,速度比较慢 

 

在这段视频中,两件儿童衬衫都被很好地分割开来。

现在,您可以在自己的视频和用例中尝试使用 SAM 2!

完整代码:

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = r"E:\segment-anything-2-main\checkpoints\sam2_hiera_large.pt"
model_cfg = r"E:\segment-anything-2-main\sam2_configs\sam2_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = r"E:\segment-anything-2-main\notebooks\videos\bedroom"
#video_dir = r"E:\segment-anything-2-main\jc-imgs"
# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# take a look the first video frame
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
plt.show()

inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)

prompts = {}  # 用于保存我们添加的所有点击以便可视化

# 添加第一个对象
ann_frame_idx = 0  # 交互的帧索引
ann_obj_id = 2  # 给每个交互对象一个唯一的ID(可以是任何整数)

# 添加一个第二个负点击 (x, y) = (275, 175) 以优化第一个对象
# 将所有点击(及其标签)发送到 `add_new_points`
points = np.array([[200, 300], [275, 175]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1, 0], np.int32)

# 将点击信息保存到 prompts 字典中
prompts[ann_obj_id] = points, labels

# 向预测器添加新点
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

ann_frame_idx = 0  # 交互的帧索引
ann_obj_id = 3  # 给每个交互对象一个唯一的ID(可以是任何整数)

# 现在我们开始处理第二个对象(给予对象ID `3`),并在 (x, y) = (400, 150) 处添加一个正点击
points = np.array([[400, 150]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels

# `add_new_points` 返回所有对象在当前交互帧上的掩码
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 显示当前帧上所有对象的结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
plt.show()

# 在整个视频中运行传播并将结果收集到一个字典中
video_segments = {}  # video_segments 包含每帧的分割结果
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# 每隔几帧渲染一次分割结果
vis_frame_stride = 15
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
    plt.show()




评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值