pytorch导出onnx

多输入网络构建、模型导出生成ONNX格式、导出的ONNX模型有效性验证三个部分

import torch.nn as nn
import torch
import numpy as np
import onnx
import onnxruntime

# -----------------------------------#
#   定义一个简单的双输入网络   
# -----------------------------------#
class MyNet_multi_input(nn.Module):
    def __init__(self, num_classes=10):
        super(MyNet_multi_input, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)   # input[3, 28, 28]  output[32, 14, 14]
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)   # input[1, 28, 28]  output[32, 14, 14]
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc = nn.Linear(48 * 14 * 14, num_classes)

    def forward(self, x, y):
        x = self.relu1(self.bn1(self.conv1(x)))
        y = self.relu2(self.bn2(self.conv2(y)))
        z = torch.cat((x, y), 1)
        z = torch.flatten(z, start_dim=1)
        z = self.fc(z)
        return z

# -----------------------------------#
#   导出ONNX模型函数
# -----------------------------------#
def multi_input_model_convert_onnx(model, input_shape, output_path):
    dummy_input1 = torch.randn(1, 3, input_shape[0], input_shape[1])
    dummy_input2 = torch.randn(1, 1, input_shape[0], input_shape[1])
    input_names = ["input1", "input2"]        # 导出的ONNX模型输入节点名称
    output_names = ["output1"]      # 导出的ONNX模型输出节点名称

    torch.onnx.export(
        model,
        (dummy_input1, dummy_input2),
        output_path,
        verbose=False,          # 如果指定为True,在导出的ONNX中会有详细的导出过程信息description
        opset_version=11,       # 地平线目前支持为 10 or 11
        input_names=input_names,
        output_names=output_names,
    )

if __name__ == '__main__':
    multi_input_model = MyNet_multi_input()
    # print(multi_input_model)
    # 建议将模型转成 eval 模式
    multi_input_model.eval()
    # 网络模型的输入尺寸
    input_shape = (28, 28)      
    # ONNX模型输出路径
    multi_input_model_output_path = './multi_input_model.onnx'

    # 导出为ONNX模型
    multi_input_model_convert_onnx(multi_input_model, input_shape, multi_input_model_output_path)
    print("multi_input_model convert onnx finsh.")

    # -----------------------------------#
    #   复杂模型可以使用下面的方法进行简化   
    # -----------------------------------#
    # import onnxsim
    # multi_input_model_sim = onnxsim.simplify(onnx.load(multi_input_model_output_path))
    # onnx.save(multi_input_model_sim[0], "multi_input_model_sim.onnx")

    # -----------------------------------------------------------------------#
    #   第一轮ONNX模型有效性验证,用来检查模型是否满足 ONNX 标准   
    #   这一步是必要的,因为无论模型是否满足标准,ONNX 都允许使用 onnx.save 存储模型,
    #   我们都不会希望生成一个不满足标准的模型~
    # -----------------------------------------------------------------------#
    onnx_model = onnx.load(multi_input_model_output_path)
    onnx.checker.check_model(multi_input_model_output_path)
    print("onnx model check_1 finsh.")

    # ----------------------------------------------------------------#
    #   第二轮ONNX模型有效性验证,用来验证ONNX模型与Pytorch模型的推理一致性   
    # ----------------------------------------------------------------#
    # 随机初始化一个模型输入,注意输入分辨率
    x = torch.randn(size=(1, 3, input_shape[0], input_shape[1]))
    y = torch.randn(size=(1, 1, input_shape[0], input_shape[1]))
    # torch模型推理
    with torch.no_grad():
        torch_out = multi_input_model(x, y)
    # print(torch_out)            # tensor([[-0.5728,  0.1695, ..., -0.3256,  1.1357, -0.4081]])
    # print(type(torch_out))      # <class 'torch.Tensor'>

    # 初始化ONNX模型
    ort_session = onnxruntime.InferenceSession(multi_input_model_output_path)
    # ONNX模型输入初始化
    ort_inputs = {ort_session.get_inputs()[0].name: x.numpy(), ort_session.get_inputs()[1].name: y.numpy()}
    # ONNX模型推理
    ort_outs = ort_session.run(None, ort_inputs)
    # print(ort_outs)             # [array([[-0.5727689 ,  0.16947027,  ..., -0.32555276,  1.13574252, -0.40812433]], dtype=float32)]
    # print(type(ort_outs))       # <class 'list'>,里面是个numpy矩阵
    # print(type(ort_outs[0]))    # <class 'numpy.ndarray'>
    ort_outs = ort_outs[0]        # 把内部numpy矩阵取出来,这一步很有必要

    # print(torch_out.numpy().shape)      # (1, 10)
    # print(ort_outs.shape)               # (1, 10)

    # ----------------------------------------------------------------#
    # 比较实际值与期望值的差异,通过继续往下执行,不通过引发AssertionError
    # 需要两个numpy输入
    # ----------------------------------------------------------------#
    np.testing.assert_allclose(torch_out.numpy(), ort_outs, rtol=1e-03, atol=1e-05)
    print("onnx model check_2 finsh.")

 

更多内容可参考 PyTorch官方导出ONNX模型教程

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值