在上一篇文章中(Pytorch导出ONNX模型),我们介绍了如何导出单静态输入Pytorch模型的ONNX格式。然而,有很多AI模型需要多个输入参量来进行推理,比如 BERT。这时候我们在导出ONNX模型的时候就要对相应的伪输入参量进行处理,把他们放入一个Python元组中。这里构造一个需要三个输入参量的简单神经元网络为例:
import torch
import torch.onnx
# 假设模型是一个接受三个输入参数的Neural Network
class MyModel(torch.nn.Module):
def forward(self, input1, input2, input3):
return torch.relu(input1 + input2 + input3)
# 初始化模型
model = MyModel()
# 创建模型的输入张量
input1 = torch.randn(1, 3, 224, 224)
input2 = torch.randn(1, 3, 224, 224)
input3 = torch.randn(1, 1)
# 将输入张量放入一个元组中
inputs = (input1, input2, input3)
# 导出模型
torch.onnx.export(model,
inputs,
"model.onnx",
export_params = True,
opset_version = 11,
do_constant_folding = True,
input_names = ['input1', 'input2','input3'], # 注意这里
output_names = ['output'],
dynamic_axes = {'input1' : {0 : 'batch_size'},
'input2' : {0 : 'batch_size'},
'input3' : {0 : 'batch_size'},
'output' : {0 : 'batch_size'}})
注意Inputs = (input1, input2, input3)这一行以及input_names和dynamic_axes参数的变化。这里我们继续用 Netron 进行可视化分析,初步验证导出的ONNX模型,如下图所示:
整体来讲并没有什么难度,只是将输入参量按顺序放入了元组中并对torch.onnx.export函数的参数进行了简单修改。但是如果不注意,导出模型的时候可能会卡住,耽误很多时间。小伙伴们学会了吗。