将模型的多个输入构造为dict,输入到torch.onnx.export()函数中,实现多输入模型导出onnx
import torch.onnx
import torch
import torch.nn as nn
# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1) # 一个全连接层,输入大小为10,输出大小为1
def forward(self, x1, x2):
x = torch.concatenate((x1, x2), dim=1)
x = self.fc(x)
return x
# 实例化模型
model = SimpleModel()
# 创建多个输入张量
input_data1 = torch.randn(1, 5)
input_data2 = torch.randn(1, 5)
# 将多个输入张量打包成字典
input_dict = {"x1": input_data1, "x2": input_data2}
# 调用 torch.onnx.export 函数并传递字典作为输入
torch.onnx.export(model, input_dict, "model_temp.onnx")