Pytorch转onnx并部署

1 篇文章 0 订阅
1 篇文章 0 订阅

    训练好一个模型,或找到开源模型后,转为onnx并通过onnxruntime(CPU,GPU)来部署服务。

    Pytorch模型转为onnx可参考此文档

    文档缺少了一些内容,下面就按照自己的实践经验再整理一遍。

    首先,准备训练好的Pytorch模型,将模型通过model.eval()转为推理模式。detect_model:为举例用的网络。

import torch

model_path = "model/path"
model_weights = torch.load(model_path)
model = detect_model().to(device)
model.load_state_dict(model_weights)
model.eval()
转换之前需要指定输入参数。很多网络都是需要动态指定输入参数的。文档中提到了如何动态调整batch_size,以下代码是如何动态指定输入的宽高。
batch_size = 1
x = torch.randn(batch_size, 3, 320, 640, requires_grad=True).to(device)
torch_out = torch_model(x)

torch.onnx.export(torch_model, x, "detect.onnx",
                      export_params=True, opset_version=10,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          'input': [2, 3],
                          'output': [2, 3]
                      })

可通过代码来验证,转换是否成功

import onnx

onnx_model = onnx.load("detect.onnx")
onnx.checker.check_model(onnx_model)

转换为onnx模型后,需要使用onnxruntime来运行模型。

import onnxruntime

ort_session = onnxruntime.InferenceSession("detect.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

如果想调用GPU资源,需要卸载onnxruntime,并根据cuda版本安装正确的onnxruntime-gpu。并将上方代码修改为:

import onnxruntime

ort_session = onnxruntime.InferenceSession("detect.onnx")
ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(x, 'cuda', 0)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_inputs = {ort_session.get_inputs()[0].name: ortvalue}
ort_outs = ort_session.run(None, ort_inputs)

 

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值