Pytorch模型转ONNX模型

Pytouch模型转ONNX

将PyTorch转为ONNX模型很简单,使用torch.onnx.export()函数即可。
函数说明:

torch.onnx.export()
功能:
将.pth模型转为onnx文件导出

参数:

  • model(torch.nn.Module): pth模型文件
  • args (tuple of arguments): 模型的输入,模型的尺寸
  • export_params (bool, default True):
    如果指定为True或者默认,参数也会被导出,如果要导出一个没训练过的就设置为False
  • verbose (bool, default False): 导出轨迹的调试描述
  • training (bool, default False) :在训练模式下导出模型。目前,ONNX导出的模型只是为了做推断,通常不需要将其设置为True;
  • input_names (list of strings, default empty list) :onnx文件的输入名称, 可以随便取
  • output_names (list of strings, default empty list) :onnx文件的输出名称,可以随便取
  • opset_version:默认为9
  • dynamic_axes – {‘input’ : {0 : ‘batch_size’}, ‘output’ : {0 :
    ‘batch_size’}})

转换示例代码

import torch
import torch.nn as nn


# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

    # 创建模型实例并加载预训练权重(如果有)


model = SimpleModel()
# model.load_state_dict(torch.load('path_to_your_model.pth'))
model.eval()  # 设置模型为评估模式

# 创建一个示例输入张量
dummy_input = torch.randn(1, 10)  # 这里的大小应该与你的模型输入大小一致

import torch.onnx

# 导出模型为ONNX格式
torch.onnx.export(
    model,  # 要转换的模型
    dummy_input,  # 示例输入张量
    "simple_model.onnx",  # 导出的ONNX模型文件名
    verbose=True,           # 是否输出ONNX导出器日志记录
    input_names=['input'],  # 输入节点的名称
    output_names=['output'],  # 输出节点的名称
    opset_version=14,  # ONNX opset版本
    do_constant_folding=True,  # 是否执行常量折叠优化
)

print("模型已成功导出为 simple_model.onnx")

验证模型是否有效:

import onnx

# Preprocessing: load the ONNX model
model_path = 'simple_model.onnx'
onnx_model = onnx.load(model_path)

print('The model is:\n{}'.format(onnx_model))

# Check the model
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print('The model is invalid: %s' % e)
else:
    print('The model is valid!')

ONNX模型精度转化

from onnx import load_model, save_model
from onnxmltools.utils import convert_float_to_float16


text_fp16_onnx_path = "simple_model_fp16.onnx"
text_fp32_onnx_model = load_model("simple_model.onnx")
text_fp16_onnx_model = convert_float_to_float16(text_fp32_onnx_model, keep_io_types=True, disable_shape_infer=True)
save_model(text_fp16_onnx_model,
            text_fp16_onnx_path,
            location=f"{text_fp16_onnx_path}.extra_file",
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            size_threshold=1024,
            convert_attribute=True)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

现实、狠残酷

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值