pytorch模型转成ONNX

使用conda 虚拟环境
Python3.6.8 Pytorch1.1版本, torchvision 0.3.0版本

安装:pip install onnx onnxruntime

yolov3 pytorch 模型转onnx 代码:
参考pytorch教程:https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

# Some standard imports
from models import *
import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import argparse
import PIL
from PIL import Image
import torchvision.transforms as transforms

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_def", type=str, default="config/yolov3-custom.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/custom.data", help="path to data config file")
    parser.add_argument("--pretrained_weights", type=str,
                        default="yolov3_custom_checkpoints_add/yolov3_ckpt_19.pth",
                        help="if specified starts from checkpoint model")
    opt = parser.parse_args()
    # print(opt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_model = Darknet(opt.model_def).to(device)
    torch_model.load_state_dict(torch.load(opt.pretrained_weights, map_location='cpu'))
    torch_model.eval()

    # Input to the model
    img = Image.open("test_00002627.jpg")
    resize = transforms.Resize([416, 416])
    img = resize(img)
    to_tensor = transforms.ToTensor()
    img = to_tensor(img)
    img.unsqueeze_(0)
    x = img.cuda()
    print(x.shape)
    # exit()
    torch_out = torch_model(x)

    # Export the model
    torch.onnx.export(torch_model,               # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      "YoloV3_custom.onnx",        # where to save the model (can be a file or file-like object)
                      input_names=['input'],     # the model's input names
                      output_names=['output'],   # the model's output names
                      )

import onnx
onnx_model = onnx.load("YoloV3_custom.onnx")
onnx.checker.check_model(onnx_model)

# import onnxruntime
#
# ort_session = onnxruntime.InferenceSession("YoloV3_custom.onnx")
#
# def to_numpy(tensor):
#     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
#
# # compute ONNX Runtime output prediction
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
# ort_outs = ort_session.run(None, ort_inputs)
# print(ort_outs[0])
# print(torch_out)
#
# # compare ONNX Runtime and PyTorch results
# np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-01, atol=1e-01)
#
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")


出现错误提示:

File “D:\Anaconda3\lib\site-packages\torch\onnx\symbolic.py”, line 90, in _parse_arg
raise RuntimeError("Failed to export an ONNX attribute, "
RuntimeError: Failed to export an ONNX attribute, since it’s not constant, please try to make things (e.g., kernel size) static if possible

问题查询参考:
https://github.com/onnx/tutorials/issues/137
https://blog.csdn.net/Cxiazaiyu/article/details/91129657

https://www.yht7.com/news/17243
https://www.pythonf.cn/read/142195

使用解决办法:Pytorch 版本降至1.0

  1. pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch==1.0.1

  2. pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torchvision==0.2.1

博客pytorch 版本升至 1.2 解决,是版本匹配问题。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值