7月29日,META发布了Segment Anything Model 2 (SAM 2) ,一个通用的视觉分割系统,它不仅适用于图像,也适用于视频。它是对 Segment Anything (SA) 模型的升级,后者主要针对图像中的可提示分割。考虑到现实世界的视觉片段在视频中展现出的复杂动态特性,以及多媒体内容中视频数据的日益增长,因此SAM 2 提供了一种能够同时处理图像和视频的统一模型。
SAM 2 引入了Promptable Visual Segmentation (PVS)任务,将图像分割的概念推广到了视频领域。它可以接受视频中任何帧上的点、框或掩码作为输入,定义要分割的目标区域,并预测其时空范围——即所谓的‘masklet’。一旦预测出masklet,就可以通过在额外帧中提供prompt来迭代细化分割结果。
SAM 2 的关键创新之一是引入了流式内存模块,这使得模型能够存储目标对象的信息和之前的交互,从而允许它在视频中生成masklet预测,并且基于先前观察到的帧中存储的对象记忆上下文可以有效地修正这些预测。而当应用于图像时,由于没有历史记忆,模型的行为类似于原来的SAM。
此外,SAM 2 在视频分割任务中,与之前的模型相比,只需要三分之一的用户交互就能达到更好的精度。而在图像分割任务中,SAM 2 比 SAM 更准确且运行速度提高了六倍。这些改进归功于 SAM 2 中更小但更高效的核心图像编码器 Hiera,以及在混合了静态图像和视频的大规模数据集上的训练。
Demo网址:https://sam2.metademolab.com
代码仓库:https://github.com/facebookresearch/segment-anything-2
官方网站:https://ai.meta.com/sam2
根据 Meta 提供的 SAM 2 demo,我们也对实际的效果进行了初步的体验,步骤也比较简单,只需 clone git 仓库并搭建好虚拟环境即可
git clone git@github.com:facebookresearch/segment-anything-2.git
cd segment-anything-2; pip install -e .
使用SAM 2进行图像分割
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Run everything below under float16 autocast on CUDA. The context manager is
# entered here and is not exited anywhere in this script.
torch.autocast(device_type="cuda", dtype=torch.float16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # Turn on TF32 for GPUs with compute capability >= 8 (Ampere)
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
def show_mask(mask, ax, random_color=False, borders=True):
    """Draw a segmentation mask on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width). It is cast
            to ``uint8`` and reshaped to ``(height, width, 1)``.
        ax: Object exposing ``imshow`` (e.g. a matplotlib axes).
        random_color: If true, use a random RGB color; otherwise a fixed blue.
            The alpha channel is 0.6 in both cases.
        borders: If true, also draw the mask's external contours with OpenCV.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) color -> (h, w, 4) RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    if borders:
        # Imported lazily so OpenCV is only required when borders are drawn.
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array of shape (N, 2); column 0 is plotted as x, column 1 as y.
        labels: Array of N labels used as a boolean selector on ``coords``.
        ax: Object exposing ``scatter`` (e.g. a matplotlib axes).
        marker_size: Marker size passed to ``scatter`` as ``s``.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
    """Draw a green, unfilled rectangle on ``ax`` for a box prompt.

    Args:
        box: Indexable of four numbers read as ``(x0, y0, x1, y1)``; width and
            height are computed as ``x1 - x0`` and ``y1 - y0``.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib axes).
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show each mask over ``image`` in its own figure, with optional prompts.

    Args:
        image: Image passed to ``plt.imshow`` as the background.
        masks: Iterable of masks, paired element-wise with ``scores``.
        scores: Sequence of per-mask scores; shown in the title when there is
            more than one.
        point_coords: Optional point prompts forwarded to ``show_points``.
        box_coords: Optional box prompt forwarded to ``show_box``.
        input_labels: Labels for ``point_coords``; required when points are given.
        borders: Forwarded to ``show_mask``.
    """
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()
# Read the demo picture, force it to RGB and hand it over as a NumPy array.
image = np.array(Image.open('notebooks/images/classroom.jpg').convert("RGB"))
导入模型,并且指定坐标点
# Build the SAM 2 model on the GPU from its checkpoint and config, wrap it in
# an image predictor and register the image to segment.
sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.set_image(image)
# One prompt point given as (x, y); label 1 is drawn as a positive (green) click by show_points.
input_point = np.array([[535, 354]])
input_label = np.array([1])
# Show the image with the prompt point on top.
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
# Debug output: shapes of the cached image embedding (reads the predictor's private `_features`).
print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True, # whether to return multiple masks
)
# Order masks, scores and logits from highest to lowest score.
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
输出mask
# NOTE(review): this call and the sorting below repeat the preceding prediction with identical inputs.
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True, # whether to return multiple masks
)
# Order masks, scores and logits from highest to lowest score.
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
# NOTE(review): the mask visualisation is commented out, so only figures created earlier are displayed here.
# show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
# plt.axis('on')
plt.show()
可以看出用一个点还是会出现分割不清的状况,这个时候可以添加第二个坐标点来进行更进一步的分割
# Refine the segmentation with two positive clicks. The variable must be
# `input_point` (it was misspelled `nput_point`), otherwise the predictor
# receives one point together with two labels.
input_point = np.array([[535, 354], [558, 100]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],  # add a leading axis to the chosen logits
    multimask_output=False,
)
show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)
plt.show()
通过二次添加坐标点,可以很好地分割出想要的对象。
使用SAM 2进行视频图像分割,首先是导入所有视频帧,然后通过第一帧图像进行交互,进行坐标点指定
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# If importing sam2 fails with "cannot import name '_C' from 'sam2'", build the
# extension in place first: python setup.py build_ext --inplace

# Run everything below under float16 autocast on CUDA. The context manager is
# entered here and is not exited anywhere in this script.
torch.autocast(device_type="cuda", dtype=torch.float16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # Turn on TF32 for GPUs with compute capability >= 8 (Ampere)
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

from sam2.build_sam import build_sam2_video_predictor

# Build the video predictor from the same checkpoint/config pair used for images.
sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw a mask on ``ax`` as a translucent overlay colored per object id.

    Args:
        mask: Array whose last two dimensions are (height, width).
        ax: Object exposing ``imshow`` (e.g. a matplotlib axes).
        obj_id: Index into matplotlib's ``tab10`` colormap; ``None`` uses index 0.
        random_color: If true, ignore ``obj_id`` and use a random RGB color.
            The alpha channel is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    h, w = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) color -> (h, w, 4) RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def store_mask(mask, ax, obj_id=None, random_color=False):
    """Render a mask as a translucent RGBA image and save it as ``image_<obj_id>.png``.

    Args:
        mask: Array whose last two dimensions are (height, width).
        ax: Unused by this function.
        obj_id: Index into matplotlib's ``tab10`` colormap (``None`` uses
            index 0); also inserted into the output file name.
        random_color: If true, ignore ``obj_id`` for coloring and use a random
            RGB color. The alpha channel is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    # The overlay holds floats in [0, 1], while an RGBA image needs 8-bit
    # channels. Passing the float array directly would reinterpret its raw
    # bytes as pixels, so scale to 0-255 and convert to uint8 first.
    rgba = np.rint(mask_image * 255).astype(np.uint8)
    img_pil = Image.fromarray(rgba)  # (h, w, 4) uint8 is read as RGBA
    img_pil.save("image_{}.png".format(obj_id))
def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array of shape (N, 2); column 0 is plotted as x, column 1 as y.
        labels: Array of N labels used as a boolean selector on ``coords``.
        ax: Object exposing ``scatter`` (e.g. a matplotlib axes).
        marker_size: Marker size passed to ``scatter`` as ``s``.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
# `video_dir` is a directory of JPEG frames named `<frame_index>.jpg`
video_dir = "notebooks/videos/bedroom"

# Collect every JPEG file name in that directory, ordered by the integer
# frame index encoded in the file name (not lexicographically).
frame_names = sorted(
    (
        name
        for name in os.listdir(video_dir)
        if os.path.splitext(name)[-1] in (".jpg", ".jpeg", ".JPG", ".JPEG")
    ),
    key=lambda name: int(os.path.splitext(name)[0]),
)
# Take a look at the first video frame: open it from `video_dir` and show it
# in a titled figure.
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
# Create the predictor's inference state for the frames in `video_dir`, then reset it.
inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
# Let's add a positive click at (x, y) = (413, 256) to get started
points = np.array([[413, 256]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
# Threshold the first object's mask logits at 0 to get a boolean mask for display.
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
# Let's add a 2nd positive click at (x, y) = (389, 119) to refine the mask
# sending all clicks (and their labels) to `add_new_points`
points = np.array([[389, 119], [413, 256]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1, 1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(12, 8))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
# Threshold the first object's mask logits at 0 to get a boolean mask for display.
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
plt.show()
指定一个坐标点,可能达不到我们想要的分割目标,
和图像分割类似,我们进行二次坐标指定
可以看到,通过二次指定可以很好地达到想要的分割效果,然后就是对视频逐帧进行分割了,将分割后的图像保存,然后再转换成视频,可以看看具体效果
# Run propagation throughout the video and collect the results in a dict that
# maps frame index -> {object id: boolean mask array}.
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        # Threshold each object's logits at 0 and move the mask to NumPy on the CPU.
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
# Render the segmentation result of every frame to `output_frame/<index>.png`.
# The stride is 1 because each frame is needed to rebuild the video afterwards;
# raise it to render only every N-th frame.
vis_frame_stride = 1
folder_path = "output_frame"
os.makedirs(folder_path, exist_ok=True)  # savefig does not create missing folders

plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    fig = plt.figure(figsize=(6, 4))
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
    plt.axis('off')
    frame_path = os.path.join(folder_path, "{}.png".format(out_frame_idx))
    plt.savefig(frame_path, transparent=True, bbox_inches='tight', pad_inches=0)
    # Close the figure once saved so open figures do not pile up across frames.
    plt.close(fig)
可以看到,每一帧都能够准确地识别分割
然后我们通过opencv将所有图像重组成视频
import cv2
import os
def _frame_sort_key(filename):
    """Sort numerically named frames (``12.png``) by number; other names go last."""
    stem = os.path.splitext(filename)[0]
    try:
        return (0, int(stem), filename)
    except ValueError:
        return (1, 0, filename)


# Folder holding the rendered frames and the name of the video to write
image_folder = 'output_frame'
video_name = 'video_output.mp4'

# os.listdir returns names in arbitrary order, and a plain string sort would
# put "10.png" before "2.png", so order the frames by their numeric index.
images = sorted(
    (img for img in os.listdir(image_folder) if img.endswith(".png")),
    key=_frame_sort_key,
)
if not images:
    raise FileNotFoundError(f"no .png frames found in {image_folder!r}")

# Take the video dimensions from the first frame
frame = cv2.imread(os.path.join(image_folder, images[0]))
if frame is None:
    raise ValueError(f"could not read frame {images[0]!r}")
height, width = frame.shape[:2]

# Frame rate of the output video
fps = 30

# Define the codec and create the VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or use 'XVID'
video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
try:
    # Append the frames to the video in order
    for image_name in images:
        frame = cv2.imread(os.path.join(image_folder, image_name))
        if frame is None:
            raise ValueError(f"could not read frame {image_name!r}")
        video.write(frame)
finally:
    # Release the VideoWriter even if a frame fails, so the file is finalized
    video.release()
最终呈现效果如下: