环境依赖
Python环境依赖
CPU版本推理:onnxruntime
GPU版本推理:onnxruntime-gpu
torchvision
PIL
torch
netron
模型转化
把 PyTorch 模型转换成 ONNX 模型时,我们往往只需要轻松地调用一句torch.onnx.export就可以了。这个函数的接口看上去简单,但它在使用上还有着诸多的注意事项。
前三个必选参数为模型、模型输入、导出的 onnx 文件名,我们对这几个参数已经很熟悉了。我们来着重看一下后面的一些常用可选参数。
- export_params
模型中是否存储模型权重。一般中间表示包含两大类信息:模型结构和模型权重,这两类信息可以在同一个文件里存储,也可以分文件存储。ONNX 是用同一个文件表示记录模型的结构和权重的。
我们部署时一般都默认这个参数为 True。如果 onnx 文件是用来在不同框架间传递模型(比如 PyTorch 到 Tensorflow)而不是用于部署,则可以令这个参数为 False。 - input_names, output_names
设置输入和输出张量的名称。如果不设置的话,会自动分配一些简单的名字(如数字)。
ONNX 模型的每个输入和输出张量都有一个名字。很多推理引擎在运行 ONNX 文件时,都需要以“名称-张量值”的数据对来输入数据,并根据输出张量的名称来获取输出数据。在进行跟张量有关的设置(比如添加动态维度)时,也需要知道张量的名字。
在实际的部署流水线中,我们都需要设置输入和输出张量的名称,并保证 ONNX 和推理引擎中使用同一套名称。 - opset_version
转换时参考哪个 ONNX 算子集版本,默认为 9。后文会详细介绍 PyTorch 与 ONNX 的算子对应关系。 - dynamic_axes
指定输入输出张量的哪些维度是动态的。
附代码
import torch
# 这一块区域为模型加载的步骤具体可以依据自己使用情况替换
from model import efficientnetv2_s as create_model
device = "cpu"
model = create_model(num_classes=2).to(device)
model_weight_path = "./weights1/model-54.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
batch_size = 1 # 批处理大小
input_shape = (3, 224, 224) # 输入数据
x = torch.randn(batch_size, *input_shape) # 生成张量
export_onnx_file = "test.onnx" # 目的ONNX文件名
torch.onnx.export(model,
x,
export_onnx_file,
opset_version=10,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入名
output_names=["output"], # 输出名
dynamic_axes={"input": {0: "batch_size"}, # 批处理变量
"output": {0: "batch_size"}})
onnx模型可视化
import netron
modelData = "./test.onnx"
netron.start(modelData)
使用onnx推理
import os, sys
sys.path.append(os.getcwd())
import onnxruntime
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# 自定义的数据增强
def get_test_transform():
return transforms.Compose([
transforms.Resize([224, 224]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 推理的图片路径
image = Image.open('./0a0b8641cac0ce40315e38af020bb18f-device4-0-f_items6.jpg').convert('RGB')
img = get_test_transform()(image)
img = img.unsqueeze_(0) # -> NCHW, 1,3,224,224
# 模型加载
onnx_model_path = "test.onnx"
resnet_session = onnxruntime.InferenceSession(onnx_model_path)
inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
outs = resnet_session.run(None, inputs)[0]
print("onnx weights", outs)
print("onnx prediction", outs.argmax(axis=1)[0])