import cv2 # 导入OpenCV用于视频处理
import time # 导入time用于测量执行时间
import numpy as np # 导入NumPy用于数值计算
import onnxruntime as ort # 导入ONNX Runtime用于运行模型
import os # 导入OS用于文件和目录操作
from moviepy.editor import VideoFileClip # 从moviepy导入VideoFileClip用于视频编辑
def normalize(frame: np.ndarray) -> np.ndarray:
"""
对帧进行归一化处理以供模型推理。
- 将帧转换为float32并缩放到[0, 1]范围
- 将BGR转换为RGB
- 转置为(C, H, W)形状
- 添加批量维度,变成(1, C, H, W)
"""
img = frame.astype(np.float32).copy() / 255.0 # 缩放像素值到[0, 1]
img = img[:, :, ::-1] # 将BGR转换为RGB
img = np.transpose(img, (2, 0, 1)) # 改变形状为(C, H, W)
img = np.expand_dims(img, axis=0) # 添加批量维度: (1, C, H, W)
return img
def infer_rvm_video(weight: str, video_path: str, output_path: str):
"""
使用ONNX模型处理视频文件并保存处理后的视频。
- 加载ONNX模型
- 读取视频文件
- 用模型处理每一帧
- 将结果保存到临时视频文件
- 将原视频的音频添加到处理后的视频中
- 保存最终输出视频
"""
sess = ort.InferenceSession(weight) # 加载ONNX模型
print(f"Load checkpoint/{weight} done!") # 打印模型加载确认信息
video_capture = cv2.VideoCapture(video_path) # 打开视频文件进行读取
width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)) # 获取视频宽度
height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 获取视频高度
frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) # 获取总帧数
fps = int(video_capture.get(cv2.CAP_PROP_FPS)) # 获取视频帧率
print(f"Video Capture: Height: {height}, Width: {width}, Frame Count: {frame_count}")
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 定义视频写入的编解码器
temp_video_path = "temp_video.mp4" # 临时处理视频的路径
video_writer = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height)) # 初始化视频写入器
print(f"Create Video Writer: {temp_video_path}")
i = 0 # 帧索引
rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4 # 初始化递归缓存
downsample_ratio = np.array([0.25], dtype=np.float32) # 下采样比率
bgr = np.array([0.47, 1., 0.6]).reshape((3, 1, 1)) # 背景颜色的BGR格式
print(f"Infer {video_path} start ...")
while video_capture.isOpened(): # 遍历视频中的每一帧
success, frame = video_capture.read() # 读取一帧
if success:
i += 1 # 增加帧索引
src = normalize(frame) # 归一化帧
t1 = time.time() # 推理开始时间
fgr, pha, *rec = sess.run([], { # 执行推理
'src': src,
'r1i': rec[0],
'r2i': rec[1],
'r3i': rec[2],
'r4i': rec[3],
'downsample_ratio': downsample_ratio
})
t2 = time.time() # 推理结束时间
print(f"Infer {i}/{frame_count} done! -> cost {(t2 - t1) * 1000} ms", end=" ")
merge_frame = fgr * pha + bgr * (1. - pha) # 合成帧
merge_frame = merge_frame[0] * 255. # 还原到[0, 255]范围
merge_frame = merge_frame.astype(np.uint8) # 转换为uint8
merge_frame = np.transpose(merge_frame, (1, 2, 0)) # 改变形状为(H, W, C)
merge_frame = cv2.cvtColor(merge_frame, cv2.COLOR_BGR2RGB) # 将BGR转换为RGB
merge_frame = cv2.resize(merge_frame, (width, height)) # 调整为原始尺寸
video_writer.write(merge_frame) # 将帧写入视频
print(f"write {i}/{frame_count} done.")
else:
print("can not read video! skip!") # 处理读取错误
break
video_capture.release() # 释放视频捕捉
video_writer.release() # 释放视频写入器
print(f"Infer {video_path} done!")
# 使用moviepy添加音频
original_clip = VideoFileClip(video_path) # 加载原始视频剪辑
processed_clip = VideoFileClip(temp_video_path) # 加载处理后的视频剪辑
final_clip = processed_clip.set_audio(original_clip.audio) # 将处理后剪辑的音频设置为原始音频
final_clip.write_videofile(output_path, codec='libx264') # 保存最终视频及音频
os.remove(temp_video_path) # 删除临时视频文件
def process_videos_in_directory(weight: str, input_dir: str, output_dir: str):
"""
处理目录中的所有视频文件。
- 如果输出目录不存在,则创建
- 处理输入目录中的每个.mp4视频
- 将处理后的视频保存到输出目录
"""
if not os.path.exists(output_dir): # 检查输出目录是否存在
os.makedirs(output_dir) # 如果不存在,则创建输出目录
for filename in os.listdir(input_dir): # 遍历输入目录中的所有文件
if filename.endswith(".mp4"): # 仅处理.mp4文件
input_path = os.path.join(input_dir, filename) # 构造输入文件的完整路径
output_path = os.path.join(output_dir, f"processed_{filename}") # 构造输出文件的完整路径
infer_rvm_video(weight, input_path, output_path) # 处理视频
if __name__ == "__main__":
weight = r'' # ONNX模型的路径
input_dir = r'' # 输入视频文件的目录
output_dir = r'' # 保存处理后视频文件的目录
process_videos_in_directory(weight=weight, input_dir=input_dir, output_dir=output_dir) # 处理视频
在视频中抠取人物视频,可换背景(进行绿幕处理)
于 2024-09-10 11:31:03 首次发布