pytorch转onnx以及部分算子对照表

pytorch转onnx以及部分算子对照表

前言

作者写这篇文章纯粹是太过心烦,找点东西记录一下。刚好手上在做这件事,于是记录一下。

本文包含两部分:

  1. pytorch如何转onnx
  2. pytorch与onnx的部分算子对照表(仅限于torch1.8 -> onnx opset11)

pytorch转onnx

首先我们要明确两件事:

  1. torch.onnx.export()这个函数本身就已经可以导出onnx了
  2. 这个函数的原理是,拿着input去跑一遍模型,记录input在模型中的变化(可以理解为记录计算图或者记录模型参数等等,随你怎么理解,反正input会在模型中完整的走一遍),最后根据记录到的数据导出onnx

那么,接下来我会举一个简单的例子来说明:


import torch
import torch.onnx as onnx


# 可有可无,不是重点。定义一个简单的PyTorch模型,可以换成你自己的模型
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()

    def forward(self, x):
        x = x.unsqueeze(0)
        return x


# 可有可无,不是重点。创建一个示例模型实例,如果你有pth文件可以在这里加载
model = SimpleModel()

# 定义输入张量,这个要关注一下,张量的形状必须符合你模型的要输入的模型的张量的形状,这个input会在模型里完整的跑一遍
input_tensor = torch.randn(192, 36)

# 导出模型为ONNX格式
onnx_file_path = "result.onnx"
onnx.export(model, input_tensor, onnx_file_path)

总之,这样子之后我们就可以获得一个onnx模型了。这一小节结束!!!! 开心!!!!!

pytorch与onnx部分算子对照表

让人痛苦的事情开始了,我们为什么要转onnx?那当然是为了部署咯。那要是硬件不支持怎么办?
问得好!我带你们打!怎么会有这种令人头大的事情!!!! 这种时候就要开始准备改模型了,但是网上的算子对照表比较少,可能大部分模型都比较正常吧,我这个在网络里有unsqueeze,切片等等操作,这里记录一下咯。

再次声明:本人用的torch版本为1.8,目标onnx的opset_version为11,诸位如果有遇到下面没有记录到的算子,还请用上面的这段代码自己去尝试。onnx可视化可以用Netron来做:

unsqueeze

pytorch的unsqueeze就对应着onnx的unsqueeze:
在这里插入图片描述

cat

pytorch中的cat对应着onnx中的concat:

在这里插入图片描述
在这里插入图片描述
在这里我要多哔哔两句了,还记得我们前面说的onnx是先记录再转吗?clone在这里没有显式的结构可能就是因为它被整合到了concat里了,有错误的话欢迎指出。

expand

torch中的expand对应的onnx的结构就是这样子了,比较令人意外的是,我原本以为参数不同结果会不一样,没想到都是长这个样子,那就对不住了,我只好说:实践是检验真理的唯一标准!!!!!
在这里插入图片描述

在这里插入图片描述

repeat

torch中repeat对应的onnx算子差不多就是下面这样子,话说结构居然比expand还要简单一些。
在这里插入图片描述

在这里插入图片描述

切片

对不起,不会表达,反正直接把代码和结构图给你们看
在这里插入图片描述

在这里插入图片描述

修改一些切片的值

这一个的结构就比较多变了,进行不同的操作,组合在一起的形状总是不太一样,不过不管怎么变,总是会出现ScatterND这个算子,这次我会给出完整代码,感兴趣的话可以自己拿回去修改玩一玩:

import torch
import torch.onnx as onnx



# 定义一个简单的PyTorch模型
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()

    def forward(self, x):
    # 主要改这里
        x[0] = x[0] + 1
        return x

# 创建一个示例模型实例
model = SimpleModel()

# 定义输入张量
input_tensor = torch.randn(192, 36)

# 导出模型为ONNX格式
onnx_file_path = "kkk.onnx"
onnx.export(model, input_tensor, onnx_file_path, opset_version=11)


在这里插入图片描述

索引

pytorch中对张量的索引会对应着onnx的gather算子:
在这里插入图片描述
在这里插入图片描述

就先到这里吧,写的很乱,但是总归是记录了一点什么。希望能对你有帮助,润!

  • 5
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值