torch自定义算子转onnx模型报错

报错如下

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1195, in _export
    _C._check_onnx_proto(proto, full_check=True)
RuntimeError: No Op registered for MYSELU with domain_version of 11

==> Context: Bad node spec for node. Name: MYSELU_2 OpType: MYSELU

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "get_onnx.py", line 69, in <module>
    torch.onnx.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1197, in _export
    raise torch.onnx.CheckerError(e)
torch.onnx.CheckerError: No Op registered for MYSELU with domain_version of 11

==> Context: Bad node spec for node. Name: MYSELU_2 OpType: MYSELU

报错代码

import torch
import torch.nn as nn
import torch.onnx
import torch.autograd
import os

# 定义op
class MYSELUImpl(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, p) -> torch._C.Value:
        return g.op("MYSELU", x, p, 
            g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
            attr1_s="属性"
        )
    
    @staticmethod
    def forward(ctx, x, p):
        return x* 1 / (1 + torch.exp(-x))

class MySelu(nn.Module):
    def __init__(self, n) -> None:
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())

    def forward(self, x):
        # 按官方的说法不能用forwrad,只能用apply
        return MYSELUImpl.apply(x, self.param)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 3, padding=1)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)

        self.myselu = MySelu(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.myselu(x)
        return x


# 这个包对应opset11的导出代码,如果想修改导出的细节,可以在这里修改代码
# import torch.onnx.symbolic_opset11
print("对应opset文件夹代码在这里:", os.path.dirname(torch.onnx.__file__))

model = Model().eval()
input = torch.tensor([
    # batch 0
    [
        [1,   1,   1],
        [1,   1,   1],
        [1,   1,   1],
    ],
        # batch 1
    [
        [-1,   1,   1],
        [1,   0,   1],
        [1,   1,   -1]
    ]
], dtype=torch.float32).view(2, 1, 3, 3)

output = model(input)
print(f"inference output = \n{output}")

dummy = torch.zeros(1, 1, 3, 3)
torch.onnx.export(
    model,
    dummy,
    "myselu.onnx",
    input_names=["image"],
    output_names=["output"],
    opset_version=11,
    verbose=True,
    dynamic_axes={
        "image": {0:"batch"},
        "output": {0:"batch"}
    },
    # operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
print("Done.!")

解决方法,链接ONNX custom operator runtime error - #2 by sksenthilkumar - PyTorch Forums

torch.onnx.export中添加operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK

或者把pytorch版本退到1.1,1.2

或者添加enable_onnx_checker=False,只不过这个参数已被弃用和忽略

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
TorchScript是一个将PyTorch模型换为可执行脚本的框架,而ONNX是一种跨平台的、开放的格式,可以用作深度学习模型的通用表示。将TorchScript模型换为ONNX格式,可以使模型在不同的平台、框架及硬件上运行,与其他深度学习框架进行集成,提高模型的生产性和可移植性。 TorchScript ONNX需要分为两步: 首先将TorchScript模型换为ONNX中间表示(IR,Intermediate Representation),可以使用torch.onnx.export方法将模型导出为ONNX格式。 导出ONNX模型时,需要指定输入的形状和类型以及输出的节点,此外还需指定输出的文件名。这里需要注意,PyTorch模型换为ONNX模型时,可能会发生精度损失或因为不支持的操作而失败,需要进行一些规避或调整操作。 然后,将ONNX的中间表示换为可执行模型,这可以通过onnxruntime等框架进行实现。onnxruntime是用于部署深度学习模型的高性能引擎,支持C++,C#,Python等多种编程语言和平台,可以在多种硬件上高效地运行深度学习模型。 在此之前还需要注意的是,随着TorchScript的不断发展,pytorchonnx的整合会越来越完善,也就会有越来越多的情况下,TorchScript模型换成ONNX时,可以不必换为中间表示,而可直接导出为ONNX模型。 总之,将TorchScript模型换为ONNX格式可以有效地提高模型的生产性和可移植性,并为模型的部署提供了更多的选择。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值