import os
import onnx
from PIL import Image
from torchvision import transforms
import copy
import numpy as np
import logging
import onnxruntime
from collections import OrderedDict
from onnx import shape_inference
logging.basicConfig(level=logging.INFO)
from onnx import shape_inference, TensorProto, version_converter, numpy_helper
logger = logging.getLogger("[ONNXOPTIMIZER]")
def test_model_by_onnxruntime(model):
    """Run *model* under onnxruntime on random input and dump every layer's output.

    Every node output in the graph is promoted to a graph output so the
    intermediate activations can be fetched from a single ``run`` call. Each
    result is cast to float16 and written to ``./layers_result/layer_<i>.bin``.

    Args:
        model: an ``onnx.ModelProto``. NOTE: mutated in place — its
            ``graph.output`` list is extended with every intermediate tensor.

    Returns:
        List of numpy float16 arrays, one per graph output (in the order
        reported by ``ort_session.get_outputs()``).
    """
    logger.info("Test model by onnxruntime")
    # Build a concrete input shape from the model's first input; dims with
    # dim_value == 0 (dynamic/unknown) are replaced by 1 so a single sample
    # can be fed.
    dims = model.graph.input[0].type.tensor_type.shape.dim
    image_shape = [d.dim_value if d.dim_value != 0 else 1 for d in dims]
    img = np.random.random(image_shape).astype(np.float32)

    # Expose every intermediate tensor as a graph output so onnxruntime
    # returns each layer's activation, not just the final result.
    for node in model.graph.node:
        for output in node.output:
            model.graph.output.extend([onnx.ValueInfoProto(name=output)])

    # Fix: also list CPUExecutionProvider — the original requested only CUDA,
    # which fails on machines without a GPU. onnxruntime picks the first
    # available provider in order.
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(),
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    # Feed the same random tensor to every input (assumes all inputs share
    # the first input's shape — TODO confirm for multi-input models).
    ort_inputs = {inp.name: img for inp in ort_session.get_inputs()}
    output_names = [x.name for x in ort_session.get_outputs()]
    ort_outs = ort_session.run(output_names, ort_inputs)

    # Fix: ensure the dump directory exists — ndarray.tofile does not create
    # missing directories and would raise FileNotFoundError.
    os.makedirs("./layers_result", exist_ok=True)
    for i in range(len(ort_outs)):
        ort_outs[i] = ort_outs[i].astype(np.float16)
        ort_outs[i].tofile("./layers_result/layer_" + str(i) + ".bin")
    logger.info("Test model by onnxruntime success")
    return ort_outs
if __name__ == "__main__":
    # Script entry point: load the model and dump every layer's activation.
    # Guarded so importing this module does not trigger inference as a side
    # effect.
    onnx_model = onnx.load("mobilenet_v2.onnx")
    ort_outs = test_model_by_onnxruntime(onnx_model)
# --- Blog-scrape residue (converted to comments so the file stays valid Python) ---
# Title: "onnxruntime: exporting every layer's data from a network"
# Latest recommended article published 2023-07-01 11:26:43.
# This snippet shows how to test a loaded ONNX model with ONNXRuntime:
# generate random input data, run inference, and save the outputs as binary
# files. It covers handling the model's inputs/outputs, creating an
# ONNXRuntime session, and running the model.
# (Summary originally auto-generated by CSDN.)