项目地址:
https://github.com/PeterL1n/RobustVideoMatting
文章:
Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!
PyTorch、TensorFlow、TensorFlow中的强大视频抠图功能。js,ONNX,CoreML!
稳定视频抠像 (RVM)
论文 Robust High-Resolution Video Matting with Temporal Guidance 的官方 GitHub 库。RVM 专为稳定人物视频抠像设计。不同于现有神经网络将每一帧作为单独图片处理,RVM 使用循环神经网络,在处理视频流时有时间记忆。RVM 可在任意视频上做实时高清抠像。在 Nvidia GTX 1080Ti 上实现 4K 76FPS 和 HD 104FPS。此研究项目来自字节跳动。
展示视频
观看展示视频 (YouTube, Bilibili),了解模型能力。
视频中的所有素材都提供下载,可用于测试模型:Google Drive
Demo
下载
推荐在通常情况下使用 MobileNetV3 的模型。ResNet50 的模型大很多,效果稍有提高。我们的模型支持很多框架。详情请阅读推断文档。
框架 | 下载 | 备注 |
PyTorch | rvm_mobilenetv3.pth rvm_resnet50.pth | 官方 PyTorch 模型权值。文档 |
TorchHub | 无需手动下载。 | 更方便地在你的 PyTorch 项目里使用此模型。文档 |
TorchScript | rvm_mobilenetv3_fp32.torchscript rvm_mobilenetv3_fp16.torchscript rvm_resnet50_fp32.torchscript rvm_resnet50_fp16.torchscript | 若需在移动端推断,可以考虑自行导出 int8 量化的模型。文档 |
ONNX | rvm_mobilenetv3_fp32.onnx rvm_mobilenetv3_fp16.onnx rvm_resnet50_fp32.onnx rvm_resnet50_fp16.onnx | 在 ONNX Runtime 的 CPU 和 CUDA backend 上测试过。提供的模型用 opset 12。文档,导出 |
TensorFlow | rvm_mobilenetv3_tf.zip rvm_resnet50_tf.zip | TensorFlow 2 SavedModel 格式。文档 |
TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | 在网页上跑模型。展示,示范代码 |
CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML 只能导出固定分辨率,其他分辨率可自行导出。支持 iOS 13+。s 代表下采样比。文档,导出 |
所有模型可在 Google Drive 或百度网盘(密码: gym7)上下载。
PyTorch 范例
- 1 安装 Python 库:
pip install -r requirements_inference.txt
- 2 加载模型:
import torch
from model import MattingNetwork
model = MattingNetwork('mobilenetv3').eval().cuda() # 或 "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
- 3 若只需要做视频抠像处理,我们提供简单的 API:
from inference import convert_video
convert_video(
model, # 模型,可以加载到任何设备(cpu 或 cuda)
input_source='input.mp4', # 视频文件,或图片序列文件夹
output_type='video', # 可选 "video"(视频)或 "png_sequence"(PNG 序列)
output_composition='com.mp4', # 若导出视频,提供文件路径。若导出 PNG 序列,提供文件夹路径
output_alpha="pha.mp4", # [可选项] 输出透明度预测
output_foreground="fgr.mp4", # [可选项] 输出前景预测
output_video_mbps=4, # 若导出视频,提供视频码率
downsample_ratio=None, # 下采样比,可根据具体视频调节,或 None 选择自动
seq_chunk=12, # 设置多帧并行计算
)
- 4 或自己写推断逻辑:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter
reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)
bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # 绿背景
rec = [None] * 4 # 初始循环记忆(Recurrent States)
downsample_ratio = 0.25 # 下采样比,根据视频调节
with torch.no_grad():
for src in DataLoader(reader): # 输入张量,RGB通道,范围为 0~1
fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # 将上一帧的记忆给下一帧
com = fgr * pha + bgr * (1 - pha) # 将前景合成到绿色背景
writer.write(com) # 输出帧
- 5 模型和 API 也可通过 TorchHub 快速载入。
# 加载模型
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # 或 "resnet50"
# 转换 API
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
推断文档里有对 downsample_ratio
参数,API 使用,和高阶使用的讲解。
训练和评估
请参照训练文档(英文)。
速度
速度用 inference_speed_test.py
测量以供参考。
GPU | dType | HD (1920x1080) | 4K (3840x2160) |
---|---|---|---|
RTX 3090 | FP16 | 172 FPS | 154 FPS |
RTX 2060 Super | FP16 | 134 FPS | 108 FPS |
GTX 1080 Ti | FP32 | 104 FPS | 74 FPS |
- 注释1:HD 使用
downsample_ratio=0.25
,4K 使用downsample_ratio=0.125
。 所有测试都使用 batch size 1 和 frame chunk 1。 - 注释2:图灵架构之前的 GPU 不支持 FP16 推理,所以 GTX 1080 Ti 使用 FP32。
- 注释3:我们只测量张量吞吐量(tensor throughput)。 提供的视频转换脚本会慢得多,因为它不使用硬件视频编码/解码,也没有在并行线程上完成张量传输。如果您有兴趣在 Python 中实现硬件视频编码/解码,请参考 PyNvCodec。
复现使用
知乎有个大佬把它分别用python和C++复现了RobustVideoMatting🔥2021 ONNXRuntime C++工程化记录-实现篇 - 知乎
python代码:
import cv2
import time
import argparse
import numpy as np
import onnxruntime as ort
def normalize(frame: np.ndarray) -> np.ndarray:
"""
Args:
frame: BGR
Returns: normalized 0~1 BCHW RGB
"""
img = frame.astype(np.float32).copy() / 255.0
img = img[:, :, ::-1] # RGB
img = np.transpose(img, (2, 0, 1)) # (C,H,W)
img = np.expand_dims(img, axis=0) # (B=1,C,H,W)
return img
def infer_rvm_frame(weight: str = "rvm_resnet50_fp32.onnx",
img_path: str = "test.jpg",
output_path: str = "test_onnx.jpg"):
sess = ort.InferenceSession(f'./checkpoint/{weight}')
print(f"Load checkpoint/{weight} done!")
for _ in sess.get_inputs():
print("Input: ", _)
for _ in sess.get_outputs():
print("Input: ", _)
frame = cv2.imread(img_path)
src = normalize(frame)
rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4 # 必须用模型一样的 dtype
downsample_ratio = np.array([0.25], dtype=np.float32) # 必须是 FP32
bgr = np.array([0.47, 1., 0.6]).reshape((3, 1, 1))
fgr, pha, *rec = sess.run([], {
'src': src,
'r1i': rec[0],
'r2i': rec[1],
'r3i': rec[2],
'r4i': rec[3],
'downsample_ratio': downsample_ratio
})
merge_frame = fgr * pha + bgr * (1. - pha) # (1,3,H,W)
merge_frame = merge_frame[0] * 255. # (3,H,W)
merge_frame = merge_frame.astype(np.uint8) # RGB
merge_frame = np.transpose(merge_frame, (1, 2, 0)) # (H,W,3)
merge_frame = cv2.cvtColor(merge_frame, cv2.COLOR_BGR2RGB)
cv2.imwrite(output_path, merge_frame)
print(f"infer done! saved {output_path}")
def infer_rvm_video(weight: str = "rvm_resnet50_fp32.onnx",
video_path: str = "./demo/1917.mp4",
output_path: str = "./demo/1917_onnx.mp4"):
sess = ort.InferenceSession(f'./checkpoint/{weight}')
print(f"Load checkpoint/{weight} done!")
for _ in sess.get_inputs():
print("Input: ", _)
for _ in sess.get_outputs():
print("Input: ", _)
# 读取视频
video_capture = cv2.VideoCapture(video_path)
width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"Video Caputer: Height: {height}, Width: {width}, Frame Count: {frame_count}")
# 写出视频
fps = 25
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
print(f"Create Video Writer: {output_path}")
i = 0
rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4 # 必须用模型一样的 dtype
downsample_ratio = np.array([0.25], dtype=np.float32) # 必须是 FP32
bgr = np.array([0.47, 1., 0.6]).reshape((3, 1, 1))
print(f"Infer {video_path} start ...")
while video_capture.isOpened():
success, frame = video_capture.read()
if success:
i += 1
src = normalize(frame)
# src 张量是 [B, C, H, W] 形状,必须用模型一样的 dtype
t1 = time.time()
fgr, pha, *rec = sess.run([], {
'src': src,
'r1i': rec[0],
'r2i': rec[1],
'r3i': rec[2],
'r4i': rec[3],
'downsample_ratio': downsample_ratio
})
t2 = time.time()
print(f"Infer {i}/{frame_count} done! -> cost {(t2 - t1) * 1000} ms", end=" ")
merge_frame = fgr * pha + bgr * (1. - pha) # (1,3,H,W)
merge_frame = merge_frame[0] * 255. # (3,H,W)
merge_frame = merge_frame.astype(np.uint8) # RGB
merge_frame = np.transpose(merge_frame, (1, 2, 0)) # (H,W,3)
merge_frame = cv2.cvtColor(merge_frame, cv2.COLOR_BGR2RGB)
merge_frame = cv2.resize(merge_frame, (width, height))
video_writer.write(merge_frame)
print(f"write {i}/{frame_count} done.")
else:
print("can not read video! skip!")
break
video_capture.release()
video_writer.release()
print(f"Infer {video_path} done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, default="video")
parser.add_argument("--weight", type=str, default="rvm_resnet50_fp32.onnx")
parser.add_argument("--input", type=str, default="./demo/1917.mp4")
parser.add_argument("--output", type=str, default="./demo/1917_onnx.mp4")
args = parser.parse_args()
if args.mode == "video":
infer_rvm_video(weight=args.weight, video_path=args.input, output_path=args.output)
else:
infer_rvm_frame(weight=args.weight, img_path=args.input, output_path=args.output)
"""
rvm_resnet50_fp32.onnx
rvm_mobilenetv3_fp32.onnx
PYTHONPATH=. python3 ./inference_onnx.py --input ./demo/1917.mp4 --output ./demo/1917_onnx.mp4
PYTHONPATH=. python3 ./inference_onnx.py --mode img --input test.jpg --output test_onnx.jpg
python inference_onnx.py --input ./demo/1917.mp4 --output ./demo/1917_onnx.mp4
"""