样例一已经在另一篇写过http://t.csdnimg.cn/ykThr
本篇直接从样例二开始
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from sam2.build_sam import build_sam2_video_predictor
sam2_checkpoint = r"E:\segment-anything-2-main\checkpoints\sam2_hiera_large.pt"
model_cfg = r"E:\segment-anything-2-main\sam2_configs\sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
def show_mask(mask, ax, obj_id=None, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=200):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = r"E:\segment-anything-2-main\notebooks\videos\bedroom"
#video_dir = r"E:\segment-anything-2-main\jc-imgs"
# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# take a look the first video frame
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
plt.show()
例 2:同时分割多个对象
注意:如果您之前使用此 inference_state 运行过任何跟踪,请先通过 reset_state 重置它。
predictor.reset_state(inference_state)
步骤 1:在一个画面上添加两个对象
SAM 2 还可以同时分割和跟踪两个或多个对象。当然,一种方法是逐个进行。不过,更有效的方法是将它们放在一起批处理(例如,这样我们就可以共享对象之间的图像特征,从而降低计算成本)。
这次,我们将重点放在对象部分,并对视频中两个孩子的衬衣进行分割。在此,我们为这两个对象添加提示,并为每个对象分配一个唯一的对象 ID。
prompts = {} # 用于保存我们添加的所有点击以便可视化
在第 0 帧的 (x, y) = (200, 300) 处以正点击添加第一个对象(左边孩子的衬衫)。
我们将其赋值给对象 id 2(可以是任意整数,只需每个要跟踪的对象都是唯一的),并将其传递给 add_new_points API,以区分我们正在点击的对象。
ann_frame_idx = 0 # 交互的帧索引
ann_obj_id = 2 # 给每个交互对象一个唯一的ID(可以是任何整数)
# 在第一个对象上添加一个正点击 (x, y) = (200, 300)
points = np.array([[200, 300]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1], np.int32)
# 将点击信息保存到 prompts 字典中
prompts[ann_obj_id] = points, labels
# 向预测器添加新点
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# 在当前(交互)帧上显示结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
show_points(*prompts[out_obj_id], plt.gca())
show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
plt.show()
嗯,这次我们只想选择孩子的衬衫,但模型预测整个孩子都要戴面具。让我们在 (x, y) = (275, 175) 处点击一个负值来完善预测。
在第二次负点击后,我们得到的第一个对象是左边孩子的衬衫。
在此,我们为第二个对象分配对象 ID 3(可以是任意整数,只需每个要跟踪的对象都是唯一的)。
注意:当有多个对象时,add_new_points API 将返回每个对象的掩码列表。
ann_frame_idx = 0 # 交互的帧索引
ann_obj_id = 3 # 给每个交互对象一个唯一的ID(可以是任何整数)
# 现在我们开始处理第二个对象(给予对象ID `3`),并在 (x, y) = (400, 150) 处添加一个正点击
points = np.array([[400, 150]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels
# `add_new_points` 返回所有对象在当前交互帧上的掩码
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# 显示当前帧上所有对象的结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
show_points(*prompts[out_obj_id], plt.gca())
show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
plt.show()
这次,只需点击一下,模型就能预测出我们要跟踪的衬衫的面罩。漂亮
第二步:传播提示,在整个视频中获取掩码
现在,我们在整个视频中传播两个对象的提示,以获取它们的掩码。
注意:当有多个对象时,propagate_in_video API 将返回每个对象的掩码列表。
# 在整个视频中运行传播并将结果收集到一个字典中
video_segments = {} # video_segments 包含每帧的分割结果
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
# 每隔几帧渲染一次分割结果
vis_frame_stride = 15
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
plt.figure(figsize=(6, 4))
plt.title(f"frame {out_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
plt.show()
设备问题,速度比较慢
在这段视频中,两件儿童衬衫都被很好地分割开来。
现在,您可以在自己的视频和用例中尝试使用 SAM 2!
完整代码:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from sam2.build_sam import build_sam2_video_predictor
sam2_checkpoint = r"E:\segment-anything-2-main\checkpoints\sam2_hiera_large.pt"
model_cfg = r"E:\segment-anything-2-main\sam2_configs\sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
def show_mask(mask, ax, obj_id=None, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=200):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = r"E:\segment-anything-2-main\notebooks\videos\bedroom"
#video_dir = r"E:\segment-anything-2-main\jc-imgs"
# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# take a look the first video frame
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
plt.show()
inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)
prompts = {} # 用于保存我们添加的所有点击以便可视化
# 添加第一个对象
ann_frame_idx = 0 # 交互的帧索引
ann_obj_id = 2 # 给每个交互对象一个唯一的ID(可以是任何整数)
# 添加一个第二个负点击 (x, y) = (275, 175) 以优化第一个对象
# 将所有点击(及其标签)发送到 `add_new_points`
points = np.array([[200, 300], [275, 175]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1, 0], np.int32)
# 将点击信息保存到 prompts 字典中
prompts[ann_obj_id] = points, labels
# 向预测器添加新点
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
ann_frame_idx = 0 # 交互的帧索引
ann_obj_id = 3 # 给每个交互对象一个唯一的ID(可以是任何整数)
# 现在我们开始处理第二个对象(给予对象ID `3`),并在 (x, y) = (400, 150) 处添加一个正点击
points = np.array([[400, 150]], dtype=np.float32)
# 对于 labels,`1` 表示正点击,`0` 表示负点击
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels
# `add_new_points` 返回所有对象在当前交互帧上的掩码
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# 显示当前帧上所有对象的结果
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
show_points(*prompts[out_obj_id], plt.gca())
show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
plt.show()
# 在整个视频中运行传播并将结果收集到一个字典中
video_segments = {} # video_segments 包含每帧的分割结果
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
# 每隔几帧渲染一次分割结果
vis_frame_stride = 15
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
plt.figure(figsize=(6, 4))
plt.title(f"frame {out_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
plt.show()