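一、config.txt 配置文件（第 1 行：ONNX 模型路径；第 2 行：待检测图片所在文件夹；第 3 行：检测结果输出文件夹；第 4 行：检测框颜色，按 OpenCV 的 (B,G,R) 顺序；第 5 行起：每行一个“类别编号: 类别名称”。每行中 “#” 之后的内容视为注释）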
./models/yolo11s.onnx
./datasets/img
./runs/onnx
(0,255,0)
0: person
1: bicycle
# ...（此处省略 2~77 号类别，实际使用时请逐行写全全部 80 个类别，不要保留不带 # 的 “...” 行，否则解析会出错）
78: hair drier
79: toothbrush
二、代码演示
1. 在 Python>=3.8、PyTorch>=1.8 的环境中，通过 pip 安装 ultralytics 包及其全部依赖。注意：本脚本直接使用 onnxruntime 进行推理，如环境中尚未安装，需要单独安装（使用 GPU 推理时安装 onnxruntime-gpu）。
pip install ultralytics onnxruntime
2. 源代码 ort_python.py
import os
import cv2
import numpy as np
import onnxruntime as ort
import argparse
def get_args():
    """Parse the command-line arguments.

    Returns:
        argparse.Namespace with one attribute, ``cfg``: the path of the
        configuration file (defaults to ``'config.txt'``).
    """
    cli = argparse.ArgumentParser(description="ONNX 图像检测脚本")
    cli.add_argument(
        '--cfg',
        type=str,
        default='config.txt',
        help="path to config.txt",
    )
    namespace = cli.parse_args()
    return namespace
def parse_args(config_path):
    """Read the plain-text configuration file into a dict.

    Expected layout (one item per line; anything after ``#`` is ignored):
        line 1   model path
        line 2   input image folder
        line 3   output folder
        line 4   box color, e.g. ``(0,255,0)``; ``[0,255,0]`` or a bare
                 ``0,255,0`` are accepted too. The tuple is handed to OpenCV
                 drawing calls as-is.
        line 5+  class mapping lines of the form ``<int id>: <name>``;
                 blank / comment-only lines are skipped.

    Args:
        config_path: path to the configuration file.

    Returns:
        dict with keys ``model_path``, ``image_folder``, ``output_folder``
        (str), ``box_color`` (tuple of int) and ``class_mapping``
        (dict mapping int class id -> str name).

    Raises:
        ValueError: fewer than 5 lines, or a malformed color / class line.
        OSError: the file cannot be opened.
        Every error is printed before being re-raised.
    """
    def remove_comment(line):
        # Drop an inline "# ..." comment and surrounding whitespace.
        return line.split('#')[0].strip()

    config = {}
    try:
        # 'utf-8-sig' transparently strips a UTF-8 BOM (as written by e.g.
        # Windows Notepad); with plain 'utf-8' the BOM would be glued to the
        # front of model_path. Files without a BOM read exactly as before.
        with open(config_path, 'r', encoding='utf-8-sig') as f:
            lines = f.readlines()
        if len(lines) < 5:
            raise ValueError("配置文件格式不正确,至少需要5行数据")
        config['model_path'] = remove_comment(lines[0])
        config['image_folder'] = remove_comment(lines[1])
        config['output_folder'] = remove_comment(lines[2])
        # Strip optional wrapping brackets instead of blindly cutting the
        # first/last character, so "0,255,0" also parses.
        box_color_str = remove_comment(lines[3]).strip('()[] ')
        config['box_color'] = tuple(int(c) for c in box_color_str.split(','))
        config['class_mapping'] = {}
        for line in lines[4:]:
            line = remove_comment(line)
            if line:
                # maxsplit=1 keeps class names that contain ':' intact.
                class_id, class_name = line.split(':', 1)
                config['class_mapping'][int(class_id)] = class_name.strip()
    except Exception as e:
        print(f"读取配置文件时出错: {e}")
        raise
    return config
def print_config(config):
    """Print every entry of the config dict, one entry per line.

    Nested dicts (such as ``class_mapping``) are printed as a ``key:`` header
    followed by one indented ``sub_key: sub_value`` line per entry.
    """
    print("读取的配置内容:")
    for name, item in config.items():
        if not isinstance(item, dict):
            print(f"{name}: {item}")
            continue
        print(f"{name}:")
        for inner_name, inner_item in item.items():
            print(f" {inner_name}: {inner_item}")
def load_model(model_path, providers=None):
    """Create an ONNX Runtime inference session for ``model_path``.

    Args:
        model_path: path to the ``.onnx`` model file.
        providers: optional list of execution providers, e.g.
            ``['CUDAExecutionProvider', 'CPUExecutionProvider']``. GPU builds
            of ONNX Runtime need this set explicitly. ``None`` (the default)
            leaves ONNX Runtime's own default in effect, exactly as before.

    Returns:
        The ``onnxruntime.InferenceSession``.

    Raises:
        Whatever ONNX Runtime raises (e.g. missing or invalid model file);
        the error is printed first and then re-raised.
    """
    try:
        return ort.InferenceSession(model_path, providers=providers)
    except Exception as e:
        print(f"加载模型时出错: {e}")
        raise
def preprocess_image(img, new_shape, color=(114, 114, 114), swap_rb=True):
    """Letterbox an image to ``new_shape`` and turn it into a model input tensor.

    The image is resized with its aspect ratio preserved so that it fits
    inside ``new_shape``. It is then padded evenly on both sides with
    ``color`` up to exactly ``new_shape``.

    Args:
        img: HxWx3 image as returned by ``cv2.imread`` (BGR channel order).
        new_shape: target size as ``(height, width)``, or a single int for a
            square target.
        color: padding value per channel, applied before the optional channel
            swap (i.e. in the input image's channel order).
        swap_rb: reverse the channel order (BGR -> RGB) before building the
            tensor. Ultralytics YOLO models are fed RGB images while OpenCV
            loads BGR, so this defaults to True; pass False to keep the
            channel order unchanged.

    Returns:
        tuple ``(tensor, ratio, pad)``:
          * ``tensor``: float32 array of shape (1, 3, height, width), values
            in [0, 1];
          * ``ratio``: ``(r, r)`` resize factor applied to the original image;
          * ``pad``: ``(left, top)`` pixels of padding actually added. Subtract
            it from model coordinates before dividing by ``ratio``.
    """
    shape = img.shape[:2]  # (h, w)
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    ratio = r, r
    # cv2.resize takes (width, height).
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw = (new_shape[1] - new_unpad[0]) / 2
    dh = (new_shape[0] - new_unpad[1]) / 2
    if shape[::-1] != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    # The +-0.1 splits an odd total padding as e.g. 2/3 so the sum is exact.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    if swap_rb:
        img = img[..., ::-1]
    # HWC -> CHW, add a batch axis, then scale to [0, 1] in float32 directly
    # (no float64 intermediate). Contiguous memory is what ONNX Runtime copies
    # most cheaply.
    tensor = np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)
    tensor /= 255.0
    # Report the integer offsets actually applied, so box coordinates map
    # back exactly (the half-float dw/dh could be off by 0.5 px).
    return tensor, ratio, (left, top)
def run_inference(session, input_data):
    """Run ``session`` on a single input tensor.

    The tensor is bound to the session's first declared input, and every
    model output is requested (``None`` as the output-name list).

    Returns:
        The list returned by ``session.run``.

    Raises:
        Any error raised by the session, after printing it.
    """
    try:
        first_input = session.get_inputs()[0]
        feed = {first_input.name: input_data}
        return session.run(None, feed)
    except Exception as err:
        print(f"运行推理时出错: {err}")
        raise
def postprocess_output(output, ratio, pad_info, confidence_threshold, iou_threshold):
    """Decode raw YOLO output into NMS-filtered boxes in original-image pixels.

    Args:
        output: raw model output, indexed as ``output[0, :4, :]`` (box center
            x, center y, width, height per candidate, in letterboxed-input
            pixels) and ``output[0, 4:, :]`` (per-class scores). Presumably
            (1, 4 + num_classes, num_candidates) for YOLOv8/YOLO11 detection
            exports; confirm for other models.
        ratio: ``(rx, ry)`` resize factors returned by ``preprocess_image``.
        pad_info: ``(pad_x, pad_y)`` padding returned by ``preprocess_image``.
        confidence_threshold: keep candidates whose best class score is
            strictly greater than this.
        iou_threshold: IoU threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        list of ``[x1, y1, x2, y2, score, class_id]`` (Python floats and an
        int class id), in the order returned by NMS; ``[]`` when nothing
        survives filtering.
    """
    output = np.asarray(output)
    bbox_coords = output[0, :4, :]
    class_probs = output[0, 4:, :]
    scores = class_probs.max(axis=0)
    class_ids = class_probs.argmax(axis=0)

    # Vectorized confidence filter instead of a Python loop over every candidate.
    keep = scores > confidence_threshold
    if not np.any(keep):
        print("No boxes after filtering, skipping NMS.")
        return []
    cx, cy, w, h = bbox_coords[:, keep]
    scores = scores[keep]
    class_ids = class_ids[keep]

    # Undo letterbox padding and scaling to get corners in original-image pixels.
    x1 = (cx - w / 2 - pad_info[0]) / ratio[0]
    y1 = (cy - h / 2 - pad_info[1]) / ratio[1]
    x2 = (cx + w / 2 - pad_info[0]) / ratio[0]
    y2 = (cy + h / 2 - pad_info[1]) / ratio[1]

    # cv2.dnn.NMSBoxes expects (x, y, width, height) rectangles, NOT corner
    # pairs; passing (x1, y1, x2, y2) makes it compute IoU on wrong boxes.
    nms_rects = np.stack([x1, y1, x2 - x1, y2 - y1], axis=1).tolist()
    indices = cv2.dnn.NMSBoxes(nms_rects, scores.tolist(), confidence_threshold, iou_threshold)
    # Depending on the OpenCV version the result is an Nx1 array, a flat
    # array, or an empty tuple; normalise to a flat int array.
    if indices is None:
        indices = np.empty(0, dtype=int)
    else:
        indices = np.asarray(indices, dtype=int).reshape(-1)
    if indices.size == 0:
        print("No boxes passed NMS.")
        return []

    return [
        [float(x1[i]), float(y1[i]), float(x2[i]), float(y2[i]), float(scores[i]), int(class_ids[i])]
        for i in indices
    ]
def save_results(image, nms_boxes, class_names, output_path, box_color):
    """Draw detections on ``image`` (in place) and write it to ``output_path``.

    Args:
        image: image array to draw on; it is modified in place.
        nms_boxes: sequence of ``[x1, y1, x2, y2, score, class_id]`` entries.
        class_names: class id -> display name lookup (dict or list). Ids it
            does not contain are labelled with the bare id instead of raising.
        output_path: destination file path passed to ``cv2.imwrite``.
        box_color: color tuple passed to OpenCV for boxes and label text.
    """
    if len(nms_boxes) == 0:
        print("No valid detection boxes to visualize.")
    for x1, y1, x2, y2, score, class_id in nms_boxes:
        try:
            name = class_names[class_id]
        except (KeyError, IndexError):
            # The config may list fewer classes than the model predicts.
            name = class_id
        label = f'{name}: {score:.2f}'
        top_left = (int(x1), int(y1))
        cv2.rectangle(image, top_left, (int(x2), int(y2)), box_color, 2)
        cv2.putText(image, label, (top_left[0] + 5, top_left[1] + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)
    # cv2.imwrite signals some failures (e.g. an unwritable path) only through
    # a False return value, so check it instead of failing silently.
    if not cv2.imwrite(output_path, image):
        print(f"Failed to write result image to '{output_path}'.")
def _model_input_hw(session, default=640):
    """Return the model's input size as ``(height, width)`` for preprocess_image.

    Reads dims 2 and 3 of the first input (NCHW, matching the (1, 3, H, W)
    tensor preprocess_image builds). Dynamic dims, reported by ONNX Runtime as
    strings or None, fall back to ``default`` (assumed 640, the usual YOLO
    export size; confirm for other models).
    """
    dims = session.get_inputs()[0].shape[2:4]
    return tuple(d if isinstance(d, int) and d > 0 else default for d in dims)


def main():
    """Detect objects in every image of the configured folder and save the results.

    Reads the config file given by ``--cfg`` and loads the ONNX model. Then,
    for each ``.jpg``/``.jpeg``/``.png`` file (case-insensitive) in
    ``image_folder``, it letterboxes the image, runs inference, applies
    confidence and NMS filtering, and writes an annotated ``<name>.jpg`` into
    ``output_folder``. Relative paths from the config resolve against the
    current working directory. Unreadable images are reported and skipped.

    Returns:
        None.
    """
    args = get_args()
    config_path = args.cfg
    if not os.path.exists(config_path):
        print(f"错误: 文件 '{config_path}' 不存在。")
        return None
    config = parse_args(config_path)
    print_config(config)

    session = load_model(config['model_path'])
    # preprocess_image expects (height, width); passing (width, height) broke
    # every model whose input is not square.
    new_shape = _model_input_hw(session)

    image_folder = config['image_folder']
    if not os.path.isdir(image_folder):
        print(f"错误: 文件夹 '{image_folder}' 不存在。")
        return None
    image_files = sorted(
        f for f in os.listdir(image_folder)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    total_images = len(image_files)

    # Create the output folder once, not on every loop iteration.
    output_folder = config['output_folder']
    if not os.path.exists(output_folder):
        print(f"输出文件夹不存在,为您创建新的输出文件夹 '{output_folder}'")
        os.makedirs(output_folder, exist_ok=True)

    for i, image_file in enumerate(image_files):
        image_path = os.path.join(image_folder, image_file)
        image = cv2.imread(image_path)
        if image is None:
            # Skip this file rather than aborting the whole batch.
            print(f"错误: 无法读取图像文件 '{image_path}'。请检查文件格式或路径。")
            continue
        print(f"处理图像 {i + 1}/{total_images}: {image_file}")
        preprocessed_image, ratio, pad_info = preprocess_image(image, new_shape)
        outputs = run_inference(session, preprocessed_image)
        nms_boxes = postprocess_output(outputs[0], ratio, pad_info, confidence_threshold=0.6, iou_threshold=0.4)
        output_image_path = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}.jpg")
        save_results(image, nms_boxes, config['class_mapping'], output_image_path, config['box_color'])
    return None
if __name__ == '__main__':
    # Script entry point; importing this module does not start detection.
    main()
3. 使用
方法一:默认读取当前目录下的config.txt
python ort_python.py
方法二:通过 --cfg 参数指定任意位置的配置文件（注意：配置文件中的相对路径是相对于运行命令时的当前工作目录解析的，而不是相对于配置文件所在目录）
python ort_python.py --cfg=path/to/config.txt