import cv2
import numpy as np
import onnxruntime as rt
import time
# 1. 预处理函数
def preprocess_image(image, target_size=(640, 640)):
    """
    Turn an image into a batched, channel-first float32 array.

    The image is resized with ``cv2.resize``, converted to float32 and
    divided by 255, its last axis is moved to the front, and a leading
    batch axis of length 1 is added. The channel order is left as given.

    Args:
        image: Image array with three axes (rows, columns, channels).
        target_size: Size handed to ``cv2.resize`` as its ``dsize`` argument.

    Returns:
        float32 array with axes (1, channels, rows, columns).
    """
    resized = cv2.resize(image, target_size)
    scaled = np.divide(resized.astype(np.float32), 255.0)
    # Channels-last -> channels-first, then prepend the batch axis.
    return scaled.transpose(2, 0, 1)[np.newaxis, ...]
# 2. 后处理函数
def postprocess_detections(detections, original_image, confidence_threshold=0.5, class_start=4):
    """
    Draw the detections that pass the confidence threshold onto an image.

    Each detection row is read as ``[cx, cy, w, h, ..., score_0, score_1, ...]``
    where the box is a centre point plus size in coordinates normalised to
    the image (they are multiplied by the image width and height here) and
    the class scores start at column ``class_start``.

    Args:
        detections: Array of detection rows. A single 1-D row is treated as
            one detection; an empty array draws nothing.
        original_image: Image array whose first two axes are height and
            width. It is drawn on in place.
        confidence_threshold: Detections whose best class score is below
            this value are skipped.
        class_start: Column index of the first class score. The default of 4
            means the scores directly follow the four box values, which is
            the layout of the Ultralytics RT-DETR export (no objectness
            column). Pass 5 for layouts that carry an objectness column.

    Returns:
        ``original_image`` with a box and a ``"<class_id>: <score>"`` label
        drawn for every kept detection.

    Raises:
        ValueError: If the rows contain no class-score columns.
    """
    rows = np.asarray(detections)
    if rows.size == 0:
        return original_image

    # Accept a lone 1-D row (e.g. after np.squeeze) as well as extra leading axes.
    rows = np.atleast_2d(rows)
    rows = rows.reshape(-1, rows.shape[-1])
    if rows.shape[1] <= class_start:
        raise ValueError(
            f"Expected more than {class_start} values per detection, got {rows.shape[1]}."
        )

    # Score every row at once instead of calling max/argmax per detection.
    class_scores = rows[:, class_start:]
    class_ids = np.argmax(class_scores, axis=1)
    confidences = np.max(class_scores, axis=1)
    keep = confidences >= confidence_threshold
    if not keep.any():
        return original_image

    image_height, image_width = original_image.shape[:2]
    kept = rows[keep]
    centres = kept[:, 0:2]
    half_sizes = 0.5 * kept[:, 2:4]
    scale = np.array([image_width, image_height])
    upper = np.array([image_width - 1, image_height - 1])
    # Centre/size -> corner pixels, clamped so boxes and labels stay on the image.
    top_left = np.clip((centres - half_sizes) * scale, 0, upper).astype(int)
    bottom_right = np.clip((centres + half_sizes) * scale, 0, upper).astype(int)

    font = cv2.FONT_HERSHEY_SIMPLEX
    box_color = (255, 0, 0)
    text_color = (255, 255, 255)
    for (xmin, ymin), (xmax, ymax), class_id, confidence in zip(
        top_left.tolist(),
        bottom_right.tolist(),
        class_ids[keep].tolist(),
        confidences[keep].tolist(),
    ):
        cv2.rectangle(original_image, (xmin, ymin), (xmax, ymax), box_color, 2)

        label = f"{class_id}: {confidence:.2f}"
        (text_width, text_height), baseline = cv2.getTextSize(label, font, 0.5, 1)

        # Put the label above the box; if that would leave the image, slide
        # it down so it starts at the top edge instead.
        label_top = max(ymin - text_height - baseline, 0)
        label_bottom = label_top + text_height + baseline
        cv2.rectangle(
            original_image,
            (xmin, label_top),
            (xmin + text_width, label_bottom),
            box_color,
            -1,
        )
        cv2.putText(
            original_image,
            label,
            (xmin, label_bottom - baseline),
            font,
            0.5,
            text_color,
            1,
        )
    return original_image
# 3. 视频推理函数
def infer_video(model_path, video_path, output_path, confidence_threshold=0.2):
    """
    Run an ONNX detection model over a video, preview it and save the result.

    Every frame is preprocessed with ``preprocess_image``, passed through the
    model, annotated by ``postprocess_detections`` plus a per-frame FPS
    read-out, shown in a window named ``frame`` and appended to the output
    video. Pressing ``q`` in the preview window stops early.

    Args:
        model_path: Path of the ONNX model file.
        video_path: Path of the input video.
        output_path: Path of the annotated video to write (``mp4v`` codec).
        confidence_threshold: Minimum class score for a detection to be drawn.

    Returns:
        None. If the input video or the output writer cannot be opened, an
        error is printed and the function returns without processing.
    """
    session = rt.InferenceSession(model_path)
    # Look the tensor names up once instead of on every frame.
    model_input = session.get_inputs()[0]
    input_name = model_input.name
    output_names = [node.name for node in session.get_outputs()]

    # preprocess_image produces a (1, C, H, W) batch, so when the model
    # declares a static H and W use them; otherwise keep the 640x640 default.
    target_size = (640, 640)
    input_shape = model_input.shape
    if len(input_shape) == 4 and all(
        isinstance(dim, int) and dim > 0 for dim in input_shape[2:]
    ):
        target_size = (input_shape[3], input_shape[2])  # cv2 wants (width, height)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        print("Error: Unable to open video file.")
        return

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 (or NaN); a writer needs a positive rate.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            print("Error: Unable to open output video file.")
            return

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            start_time = time.perf_counter()

            # VideoCapture yields BGR frames; the model is fed RGB.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_batch = preprocess_image(rgb_frame, target_size)

            outputs = session.run(output_names, {input_name: image_batch})
            detections = outputs[0]
            if detections.ndim == 3:
                # Drop only the batch axis so a single detection stays 2-D.
                detections = detections[0]

            frame_with_detections = postprocess_detections(
                detections, frame, confidence_threshold
            )

            # Processing rate of this frame; guard against a zero-length interval.
            elapsed_time = time.perf_counter() - start_time
            fps_display = 1.0 / elapsed_time if elapsed_time > 0 else 0.0
            cv2.putText(
                frame_with_detections,
                f"FPS: {fps_display:.2f}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (255, 0, 0),
                2,
            )

            cv2.imshow('frame', frame_with_detections)
            out.write(frame_with_detections)

            # Stop early when 'q' is pressed in the preview window.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture, writer and window even if a frame fails.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

    print(f"Video saved to {output_path}")
if __name__ == "__main__":
    # Script entry point: run the detector on the sample video using the
    # default confidence threshold of infer_video.
    model_path, video_path, output_path = (
        "./weights/rtdetr-l.onnx",
        "1.mp4",
        "1_output.mp4",
    )
    infer_video(
        model_path=model_path,
        video_path=video_path,
        output_path=output_path,
    )