rtdetr的onnx推理

import cv2
import numpy as np
import onnxruntime as rt
import time

# 1. 预处理函数
def preprocess_image(image, target_size=(640, 640)):
    image_resized = cv2.resize(image, target_size)
    image_normalized = image_resized.astype(np.float32) / 255.0
    image_transposed = np.transpose(image_normalized, [2, 0, 1])
    image_batch = np.expand_dims(image_transposed, axis=0)
    return image_batch

# 2. 后处理函数
def postprocess_detections(detections, original_image, confidence_threshold=0.5):
    """
    处理检测结果并在图像上绘制边界框和类别标签。
    """
    image_height, image_width = original_image.shape[:2]

    for detection in detections:
        # 提取类别置信度和类别索引
        class_scores = detection[5:]
        confidence = np.max(class_scores)
        class_id = np.argmax(class_scores)

        # 忽略低置信度的检测结果
        if confidence < confidence_threshold:
            continue

        # 提取边界框的中心坐标和尺寸
        cx, cy, bw, bh = detection[:4]

        # 将中心坐标和尺寸转换为边界框的左上角和右下角坐标
        xmin = int((cx - 0.5 * bw) * image_width)
        ymin = int((cy - 0.5 * bh) * image_height)
        xmax = int((cx + 0.5 * bw) * image_width)
        ymax = int((cy + 0.5 * bh) * image_height)

        # 绘制边界框
        cv2.rectangle(original_image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)

        # 准备标签文本
        label = f"{class_id}: {confidence:.2f}"

        # 计算文本尺寸
        (text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # 绘制文本背景矩形
        cv2.rectangle(original_image, (xmin, ymin - text_height - baseline), (xmin + text_width, ymin), (255, 0, 0), -1)

        # 绘制文本
        cv2.putText(original_image, label, (xmin, ymin - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    return original_image

# 3. 视频推理函数
def infer_video(model_path, video_path, output_path, confidence_threshold=0.2):
    # 加载 ONNX 模型
    session = rt.InferenceSession(model_path)

    # 获取模型的输入输出名称
    input_names = [input.name for input in session.get_inputs()]
    output_names = [output.name for output in session.get_outputs()]

    # 打开视频文件
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Unable to open video file.")
        return

    # 获取视频的帧率和分辨率
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # 设置视频输出
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    prev_time = time.time()  # 初始化时间用于FPS计算
    frame_count = 0  # 记录帧数

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # 记录处理开始时间
        start_time = time.time()

        # 预处理图像
        image_batch = preprocess_image(frame)

        # 推理
        outputs = session.run(output_names, {session.get_inputs()[0].name: image_batch})
        detections = np.squeeze(outputs[0])

        # 后处理并绘制检测结果
        frame_with_detections = postprocess_detections(detections, frame, confidence_threshold)

        # 计算FPS
        elapsed_time = time.time() - start_time
        fps_display = 1.0/elapsed_time
        cv2.putText(frame_with_detections, f"FPS: {fps_display:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

        # 显示图像
        cv2.imshow('frame', frame_with_detections)

        # 写入输出视频
        out.write(frame_with_detections)

        # 检查按键是否为 'q',如果是,则退出视频播放
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 释放资源
    cap.release()
    out.release()
    cv2.destroyAllWindows()
    print(f"Video saved to {output_path}")

if __name__ == "__main__":
    model_path = "./weights/rtdetr-l.onnx"
    video_path = "1.mp4"
    output_path = "1_output.mp4"
    infer_video(model_path, video_path, output_path)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值