将模型的输入直接传输到torch.onnx.export()函数中即可
import torch
import torch.nn as nn
# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1) # 一个全连接层,输入大小为10,输出大小为1
def forward(self, x):
x = self.fc(x)
return x
# 实例化模型
model = SimpleModel()
# 创建单个输入张量
input_data = torch.randn(1, 10)
# 调用 torch.onnx.export函数
torch.onnx.export(model,
input_data,
"model_temp.onnx",
input_names=["input"])