Pytouch模型转ONNX
将PyTorch转为ONNX模型很简单,使用torch.onnx.export()函数即可。
函数说明:
torch.onnx.export()
功能:
将.pth模型转为onnx文件导出
参数:
- model(torch.nn.Module): pth模型文件
- args (tuple of arguments): 模型的输入,模型的尺寸
- export_params (bool, default True):
如果指定为True或者默认,参数也会被导出,如果要导出一个没训练过的就设置为False - verbose (bool, default False): 导出轨迹的调试描述
- training (bool, default False) :在训练模式下导出模型。目前,ONNX导出的模型只是为了做推断,通常不需要将其设置为True;
- input_names (list of strings, default empty list) :onnx文件的输入名称, 可以随便取
- output_names (list of strings, default empty list) :onnx文件的输出名称,可以随便取
- opset_version:默认为9
- dynamic_axes – {‘input’ : {0 : ‘batch_size’}, ‘output’ : {0 :
‘batch_size’}})
转换示例代码
import torch
import torch.nn as nn
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
# 创建模型实例并加载预训练权重(如果有)
model = SimpleModel()
# model.load_state_dict(torch.load('path_to_your_model.pth'))
model.eval() # 设置模型为评估模式
# 创建一个示例输入张量
dummy_input = torch.randn(1, 10) # 这里的大小应该与你的模型输入大小一致
import torch.onnx
# 导出模型为ONNX格式
torch.onnx.export(
model, # 要转换的模型
dummy_input, # 示例输入张量
"simple_model.onnx", # 导出的ONNX模型文件名
verbose=True, # 是否输出ONNX导出器日志记录
input_names=['input'], # 输入节点的名称
output_names=['output'], # 输出节点的名称
opset_version=14, # ONNX opset版本
do_constant_folding=True, # 是否执行常量折叠优化
)
print("模型已成功导出为 simple_model.onnx")
验证模型是否有效:
import onnx
# Preprocessing: load the ONNX model
model_path = 'simple_model.onnx'
onnx_model = onnx.load(model_path)
print('The model is:\n{}'.format(onnx_model))
# Check the model
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print('The model is invalid: %s' % e)
else:
print('The model is valid!')
ONNX模型精度转化
from onnx import load_model, save_model
from onnxmltools.utils import convert_float_to_float16
text_fp16_onnx_path = "simple_model_fp16.onnx"
text_fp32_onnx_model = load_model("simple_model.onnx")
text_fp16_onnx_model = convert_float_to_float16(text_fp32_onnx_model, keep_io_types=True, disable_shape_infer=True)
save_model(text_fp16_onnx_model,
text_fp16_onnx_path,
location=f"{text_fp16_onnx_path}.extra_file",
save_as_external_data=True,
all_tensors_to_one_file=True,
size_threshold=1024,
convert_attribute=True)