参考:1.B站大学
2.ONNX官方github
2.ONNX官网
1.ONNX简介
ONNX(Open Neural Network Exchange)是一种开放格式,旨在表示机器学习模型。通俗来讲其表示一种统一的中间结构,如下图所示,可以将Pytorch、TensorFlow等模型训练框架(这些框架更适合做模型训练,不适合模型的部署推理)训练的模型转换为ONNX格式模型,然后再使用推理框架(如onnxruntime、OpenVINO、TensorRT等)进行推理,其推理速度相比于使用模型训练框架进行推理会快很多(几倍到几十倍不等)。
如果没有ONNX作为中间格式进行表示,那么一堆各种框架训练的模型和不同硬件就需要很多种适配,有了ONNX作为中间表示,各种框架训练的模型和不同硬件就只需要一种适配,即适配onnx的转换则可。
此外,还需要说明的是将训练好的模型转换成ONNX格式后其中不仅存储了神经网络的结构,还存储了模型的权重信息,使得复杂的神经网络架构可以变成一个十分简洁的文件进行表示(转换后变成.onnx格式文件)。
2.ONNX相关库
1.onnx
安装库命令
pip install onnx
onnx库是用于将训练框架的模型导出为ONNX格式
2.onnxruntime(CPU推理)
安装库命令
pip install onnxruntime
onnxruntime库是用于推理ONNX格式的模型,只能进行CPU推理,无法使用GPU推理。
3.onnxruntime-gpu(GPU推理)
安装库命令
pip install onnxruntime-gpu
onnxruntime-gpu库可使用GPU推理ONNX格式的模型。此外,需要值得注意的是onnxruntime-gpu库和onnxruntime库不能同时安装。
3.实际使用
接下来进行实战:,具体为:
import torch
from torchvision import models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
model = models.resnet18(pretrained=True)
model = model.eval().to(device) #将模型设置为评估模式
x = torch.randn(1,3,256,256).to(device) # 模拟输入图片
output = model(x)
print(output.shape)
with torch.no_grad():
torch.onnx.export(
model, # 需要转换的模型
x, # 模型任意一组输入
'resnet18_imagenet.onnx', # 导出的onnx文件名
opset_version=11, # ONNX算子版本
input_names=['input'], # 输入Tensor的名称(自己起名字)
output_names=['output'] # 输出Tensor的名称(自己起名字)
)
然后使用Netron工具(下载地址,打开后window系统下载Netron-Setup-7.5.8.exe,然后直接安装即可)打开转换后的resnet18_imagenet.onnx文件,打开后显示如下:
onnxruntime推理onnx格式模型
import onnx
import onnxruntime
import cv2
import numpy as np
import torch
class ResNet18():
def __init__(self,onnxpath):
self.onnx_session=onnxruntime.InferenceSession(onnxpath)
self.input_name=self.get_input_name()
self.output_name=self.get_output_name()
#-------------------------------------------------------
# 获取输入输出的名字
#-------------------------------------------------------
def get_input_name(self):
input_name=[]
for node in self.onnx_session.get_inputs():
input_name.append(node.name)
return input_name
def get_output_name(self):
output_name=[]
for node in self.onnx_session.get_outputs():
output_name.append(node.name)
return output_name
#-------------------------------------------------------
# 输入图像
#-------------------------------------------------------
def get_input_feed(self,img_tensor):
input_feed={}
for name in self.input_name:
input_feed[name]=img_tensor
return input_feed
#-------------------------------------------------------
# 1.cv2读取图像并resize
# 2.图像转BGR2RGB和HWC2CHW
# 3.图像归一化
# 4.图像增加维度
# 5.onnx_session 推理
#-------------------------------------------------------
def inference(self,img):
# img=cv2.imread(img_path)
or_img=cv2.resize(img,(256,256))
img=or_img[:,:,::-1].transpose(2,0,1) #BGR2RGB和HWC2CHW
img=img.astype(dtype=np.float32)
img/=255.0
img=np.expand_dims(img,axis=0)
input_feed=self.get_input_feed(img)
pred=self.onnx_session.run(None,input_feed)[0]
return pred,or_img
def main():
#读取ONNX模型
onnx_model = onnx.load('resnet18_imagenet.onnx')
#检查模型格式是否正确
onnx.checker.check_model(onnx_model)
print("Success!")
onnx_path = 'resnet18_imagenet.onnx'
model = ResNet18(onnx_path)
img=cv2.imread('cat.jpg')
output,org = model.inference(img)
print(output.shape)
if __name__=="__main__":
main()
结果:
如有错误欢迎指正,同时也感谢上述参考资料作者的贡献!