1. 模型单输入输出
# 创建模型的输入数据
input = torch.rand(1, 3, 224, 224)
# 导出模型
torch.onnx.export(model , # 模型
input , # 伪输入 提供模型输入的形状
r"model.onnx", # 保存的文件名
export_params=True, # 是否保存模型参数
opset_version=11, # onnx版本
do_constant_folding=True, # 是否执行常量折叠优化
input_names = ['input'], # 输入名称
output_names = ['output'], # 输出名称
dynamic_axes={'input' : {0 : 'batch_size'}, # 指定动态轴 dynamic_axes 参数的格式为一个字典,
'output' : {0 : 'batch_size'}}, # 其中键是输入或输出张量的名称,值是另一个字典,指定了动态轴的索引及其对应的名称。
# verbose=True, # 是否打印详细信息
)
2. 模型多输入
import torch
import torch.onnx
# 假设模型是一个接受两个输入参数的模型
class MyModel(torch.nn.Module):
def forward(self, input1, input2):
return input1 + input2
# 初始化模型
model = MyModel()
# 创建模型的输入数据
input1 = torch.randn(1, 3, 224, 224)
input2 = torch.randn(1, 1)
# 将输入数据放入一个元组中
inputs = (input1, input2)
# 导出模型
torch.onnx.export(model,
inputs,
"model.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names = ['input1', 'input2'],
output_names = ['output'],
dynamic_axes={'input1' : {0 : 'batch_size'},
'input2' : {0 : 'batch_size'},
'output' : {0 : 'batch_size'}})
3. 模型多输出
在
torch.onnx.export()
中output_names
是以模型多个输出的顺序来命名的。
# 模型输出 元组形式
class MyModel(torch.nn.Module):
def forward(self, x):
logit = model_1(x)
pre_logits = model_2(x)
return logit, pre_logits
# 初始化模型
model = MyModel()
# 创建模型的输入数据
input = torch.rand(1, 3, 224, 224)
# 导出模型
torch.onnx.export(model ,
input ,
r"model.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['logit', 'pre_logits'],
dynamic_axes={'input': {0: 'batch_size'},
'logit': {0: 'batch_size'},
'pre_logits': {0: 'batch_size'}})
上例模型输出是元组形式,因此
onnx
中的output_names[0]
指代模型输出的outputs['logit']
,而output_names[1]
指代模型输出的outputs['pre_logits']
。