pytorch导出onnx不成功的时候如何解决? - 学习记录

本文介绍了如何通过升级opset版本、替换PyTorch中的算子组合、onnx注册新算子以及使用onnx-surgeon创建插件来处理PyTorch模型导出到ONNX时遇到的不支持算子问题,如asinh的案例。
摘要由CSDN通过智能技术生成

1.1、修改opset的版本

  • 查看不支持的算子在新的opset中是否被支持
  • 如果不考虑自己搭建plugin的话,也需要看看onnx-trt中这个算子是否被支持

(去查看onnx官网文档)

1.2、替换pytorch中的算子组合

  • 把某些计算替换成onnx可以识别的

1.3、onnx 注册算子

  • 在pytorch登记onnx中某些算子
  • 有可能onnx中有支持,但没有被登记

1.3.1 样例

python1

import torch
import torch.onnx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x = torch.asinh(x)
        return x

def infer():
    input = torch.rand(1, 5)
    
    model = Model()
    x = model(input)
    print("input is: ", input.data)
    print("result is: ", x.data)

def export_norm_onnx():
    input   = torch.rand(1, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-asinh.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 9)
    print("Finished normal onnx export")

if __name__ == "__main__":
    infer()

    # 这里导出asinh会出现错误。
    # Pytorch可以支持asinh的同时,
    #   def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...

    # 从onnx支持的算子里面我们可以知道自从opset9开始asinh就已经被支持了
    #   asinh is suppored since opset9

    export_norm_onnx()

python1导不出onnx,可以去onnx官网查看一下,是否支持了这个算子
在这里插入图片描述
可以看到,其实已经支持了,但还是导不出。所以这里要先注册一下这个算子。
python2

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic

# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
#   g: 就是graph,计算图
#   也就是说,在计算图中添加onnx算子
#   由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
#   symblic的参数需要与Pytorch的asinh接口函数的参数对齐
#       def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的

# 那么aten是什么呢?
# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)


# 这里容易混淆的地方:
# 1. register_op中的第一个参数是PyTorch中的算子名字: aten::asinh
# 2. g.op中的第一个参数是onnx中的算子名字: Asinh

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x = torch.asinh(x)
        return x


def validate_onnx():
    input = torch.rand(1, 5)

    # PyTorch的推理
    model = Model()
    x     = model(input)
    print("result from Pytorch is :", x)

    # onnxruntime的推理
    sess  = onnxruntime.InferenceSession('../models/sample-asinh.onnx')
    x     = sess.run(None, {'input0': input.numpy()})
    print("result from onnx is:    ", x)

def export_norm_onnx():
    input   = torch.rand(1, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-asinh.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished normal onnx export")

if __name__ == "__main__":
    export_norm_onnx()

    # 自定义完onnx以后必须要进行一下验证
    validate_onnx()

主要是这个函数,用来注册算子

def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

1.4、直接修改onnx,创建plugin

• 使用onnx-surgeon
• 一般是用在加速某些算子上使用

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值