# onnx_yolov5_inference.py
import os
import sys
import time

import cv2
import numpy as np
import onnxruntime as ort

from common import decode
from common import draw_result
from common import non_max_suppression
from common import pad_to_square
from common import save_tensor
# YOLOv5 anchor boxes as (w, h) pairs, three per detection scale
# (small, medium, large objects respectively).
anchors = [[10, 13], [16, 30], [33, 23],
           [30, 61], [62, 45], [59, 119],
           [116, 90], [156, 198], [373, 326]]
# Model input size (pixels)
model_input_h = 640
model_input_w = 640
# Objectness confidence threshold
obj_conf_thresh = 0.3
# NMS IoU threshold
nms_iou_thresh = 0.5
# Number of classes
classes_num = 1
# Class labels. NOTE: ('o') without a trailing comma is just the string
# 'o', not a tuple — a one-element tuple literal needs the comma.
class_label = ('o',)
if __name__ == '__main__':
    image_name = 'B5D95V.jpg'
    test_image_path = './data/test_images/{}'.format(image_name)
    onnx_model_path = './models/onnx/best_20210531_sim.onnx'
    tensor_output_dir = './data/output_tensor'

    # Load the image and convert it to RGB.
    # Note the OpenCV array layout: shape[0]=h, shape[1]=w, shape[2]=c
    bgr_image = cv2.imread(test_image_path)
    if bgr_image is None:
        print('load image {} failed !!!'.format(test_image_path))
        sys.exit(-1)
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    # If the image size differs from the model input size, pad it to a
    # square first (preserving aspect ratio) and then resize.
    if bgr_image.shape != (model_input_h, model_input_w, 3):
        rgb_image = pad_to_square(rgb_image)
        rgb_image = cv2.resize(rgb_image, (model_input_w, model_input_h))

    # Convert to float32 — the model expects floating-point input —
    # and normalize pixel values to [0, 1].
    rgb_image = rgb_image.astype(np.float32)
    rgb_image = rgb_image / 255.0
    # Reorder dimensions: HWC => CHW
    rgb_image = rgb_image.transpose(2, 0, 1)
    # Add a batch dimension: CHW => NCHW
    input_data = np.expand_dims(rgb_image, axis=0)
    print('input_data shape = {}'.format(input_data.shape))
    print('input_data type = {}'.format(input_data.dtype))

    # Start an ONNX Runtime session.
    ort_session = ort.InferenceSession(onnx_model_path)
    # Query the model's actual input name instead of hard-coding 'data',
    # so the script also works with models exported under another name.
    input_name = ort_session.get_inputs()[0].name

    # Run inference and time it.
    start = time.time()
    model_outputs = ort_session.run(None, {input_name: input_data})
    end = time.time()
    time_ms = (end - start) * 1000
    print('inference time: {} ms'.format(time_ms))

    # Save each raw output tensor and print its shape.
    for index, output in enumerate(model_outputs):
        save_tensor(output, os.path.join(tensor_output_dir, 'onnx', 'output{}'.format(index)))
        print('yolo output {} shape: {}'.format(index, output.shape))
decode_outputs = decode(model_outputs=model_outputs,
model_input_h=model_input_h,
model_input_w=model_input_w,