知识复习 - Pytorch导出ONNX模型 多输入参量

在上一篇文章中(Pytorch导出ONNX模型),我们介绍了如何导出单静态输入Pytorch模型的ONNX格式。然而,有很多AI模型需要多个输入参量来进行推理,比如 BERT。这时候我们在导出ONNX模型的时候就要对相应的伪输入参量进行处理,把他们放入一个Python元组中。这里构造一个需要三个输入参量的简单神经元网络为例:

import torch
import torch.onnx

# 假设模型是一个接受三个输入参数的Neural Network
class MyModel(torch.nn.Module):
    def forward(self, input1, input2, input3):
        return torch.relu(input1 + input2 + input3)

# 初始化模型
model = MyModel()

# 创建模型的输入张量
input1 = torch.randn(1, 3, 224, 224)
input2 = torch.randn(1, 3, 224, 224)
input3 = torch.randn(1, 1)

# 将输入张量放入一个元组中
inputs = (input1, input2, input3)

# 导出模型
torch.onnx.export(model,              
                  inputs,              
                  "model.onnx",        
                  export_params = True,  
                  opset_version = 11,    
                  do_constant_folding = True, 
                  input_names = ['input1', 'input2','input3'],    # 注意这里
                  output_names = ['output'], 
                  dynamic_axes = {'input1' : {0 : 'batch_size'},   
                                'input2' : {0 : 'batch_size'},
                                'input3' : {0 : 'batch_size'},
                                'output' : {0 : 'batch_size'}})

注意Inputs = (input1, input2, input3)这一行以及input_names和dynamic_axes参数的变化。这里我们继续用 Netron 进行可视化分析,初步验证导出的ONNX模型,如下图所示:

整体来讲并没有什么难度,只是将输入参量按顺序放入了元组中并对torch.onnx.export函数的参数进行了简单修改。但是如果不注意,导出模型的时候可能会卡住,耽误很多时间。小伙伴们学会了吗。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值