"""Export a pretrained torchvision ResNet-18 to ONNX and verify it with ONNX Runtime.

Steps:
1. Build ResNet-18 with pretrained weights and export it to ``resnet18.onnx``.
2. Load the exported file and run the ONNX structural checker on it.
3. Run the exported model with ONNX Runtime on the same input that was used
   for export, and check the result against the PyTorch model's output.
"""
import numpy as np
import onnx
import onnxruntime as rt
import torch
from torchvision import models

ONNX_PATH = "resnet18.onnx"


def load_resnet18():
    """Return torchvision's ResNet-18 with pretrained weights, in eval mode."""
    try:
        # torchvision >= 0.13: the `weights` enum replaces the deprecated
        # `pretrained` flag.
        weights = models.ResNet18_Weights.DEFAULT
    except AttributeError:
        # Older torchvision only understands `pretrained=True`.
        net = models.resnet18(pretrained=True)
    else:
        net = models.resnet18(weights=weights)
    # Switch BatchNorm/Dropout to inference behaviour so the PyTorch reference
    # output computed in main() matches what the exported graph computes.
    return net.eval()


def main():
    """Export ResNet-18 to ONNX, validate the file, and compare ORT vs PyTorch."""
    net = load_resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)

    torch.onnx.export(
        net,
        dummy_input,
        ONNX_PATH,
        input_names=["input"],
        output_names=["output"],
        # Let the exported model accept any batch size, not only 1.
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )

    # Load the ONNX model and check that the IR is well formed.
    model = onnx.load(ONNX_PATH)
    onnx.checker.check_model(model)

    # Since ONNX Runtime 1.9, builds that ship several execution providers
    # (e.g. onnxruntime-gpu) raise unless `providers` is passed explicitly.
    # CPU keeps the numerical comparison below deterministic.
    sess = rt.InferenceSession(ONNX_PATH, providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    print("input_name:", input_name, "label_name:", label_name)

    # Feed the same tensor used for export so ORT can be checked against PyTorch.
    data = dummy_input.numpy().astype(np.float32, copy=False)
    pred_onx = sess.run([label_name], {input_name: data})[0]
    print(pred_onx.shape)
    print(np.argmax(pred_onx))

    with torch.no_grad():
        torch_out = net(dummy_input).numpy()
    # Tolerances follow the official PyTorch ONNX export tutorial.
    np.testing.assert_allclose(torch_out, pred_onx, rtol=1e-3, atol=1e-5)
    print("ONNX Runtime output matches PyTorch.")


if __name__ == "__main__":
    main()
# Model deployment with ONNX (模型部署onnx)
# Latest recommended article published 2024-08-23 11:07:44 (最新推荐文章于 2024-08-23 11:07:44 发布)