Pytorch导出模型为onnx

1. 模型单输入输出

# 创建模型的输入数据
input = torch.rand(1, 3, 224, 224)

# 导出模型
torch.onnx.export(model ,               		# 模型
                  input ,              			# 伪输入 提供模型输入的形状
                  r"model.onnx",  				# 保存的文件名
                  export_params=True,           # 是否保存模型参数
                  opset_version=11,             # onnx版本
                  do_constant_folding=True,     # 是否执行常量折叠优化
                  input_names = ['input'],      # 输入名称
                  output_names = ['output'],    # 输出名称
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 指定动态轴 dynamic_axes 参数的格式为一个字典,
                                'output' : {0 : 'batch_size'}},  # 其中键是输入或输出张量的名称,值是另一个字典,指定了动态轴的索引及其对应的名称。
                  # verbose=True,               # 是否打印详细信息
                  )

2. 模型多输入

import torch
import torch.onnx

# 假设模型是一个接受两个输入参数的模型
class MyModel(torch.nn.Module):
    def forward(self, input1, input2):
        return input1 + input2

# 初始化模型
model = MyModel()

# 创建模型的输入数据
input1 = torch.randn(1, 3, 224, 224)
input2 = torch.randn(1, 1)

# 将输入数据放入一个元组中
inputs = (input1, input2)

# 导出模型
torch.onnx.export(model,              
                  inputs,              
                  "model.onnx",        
                  export_params=True,  
                  opset_version=11,    
                  do_constant_folding=True, 
                  input_names = ['input1', 'input2'],   
                  output_names = ['output'], 
                  dynamic_axes={'input1' : {0 : 'batch_size'},   
                                'input2' : {0 : 'batch_size'},
                                'output' : {0 : 'batch_size'}})

3. 模型多输出

torch.onnx.export()output_names是以模型多个输出的顺序来命名的。
# 模型输出 元组形式
class MyModel(torch.nn.Module):
    def forward(self, x):
        logit = model_1(x)
        pre_logits = model_2(x)
        return logit, pre_logits

# 初始化模型
model = MyModel()

# 创建模型的输入数据
input = torch.rand(1, 3, 224, 224)

# 导出模型
torch.onnx.export(model , 
                  input , 
                  r"model.onnx", 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['logit', 'pre_logits'], 
                  dynamic_axes={'input': {0: 'batch_size'}, 
                  				'logit': {0: 'batch_size'},
                  				'pre_logits': {0: 'batch_size'}})

上例模型输出是元组形式,因此onnx中的output_names[0]指代模型输出的outputs['logit'],而output_names[1]指代模型输出的outputs['pre_logits']

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Pytorch 模型导出ONNX 或 TensorRT 格式的具体步骤如下: ### 导出ONNX 格式 1. 安装 onnx 包:`pip install onnx` 2. 加载 Pytorch 模型并将其转换为 ONNX 模型: ```python import torch import torchvision import onnx # 加载 Pytorch 模型 model = torchvision.models.resnet18(pretrained=True) # 转换为 ONNX 模型 dummy_input = torch.randn(1, 3, 224, 224) input_names = ["input"] output_names = ["output"] onnx_path = "resnet18.onnx" torch.onnx.export(model, dummy_input, onnx_path, verbose=True, input_names=input_names, output_names=output_names) ``` 3. 导入 ONNX 模型: ```python import onnx # 加载 ONNX 模型 onnx_path = "resnet18.onnx" model = onnx.load(onnx_path) ``` ### 导出为 TensorRT 格式 1. 安装 TensorRT 并设置环境变量: ```python # 安装 TensorRT !pip install nvidia-pyindex !pip install nvidia-tensorrt # 设置 TensorRT 环境变量 import os os.environ["LD_LIBRARY_PATH"] += ":/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu" ``` 2. 加载 Pytorch 模型并将其转换为 TensorRT 模型: ```python import tensorrt as trt import pycuda.driver as cuda import torch import torchvision # 加载 Pytorch 模型 model = torchvision.models.resnet18(pretrained=True) # 转换为 TensorRT 模型 TRT_LOGGER = trt.Logger(trt.Logger.WARNING) trt_runtime = trt.Runtime(TRT_LOGGER) with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser: builder.max_workspace_size = 1 << 30 builder.max_batch_size = 1 # 加载 ONNX 模型 onnx_path = "resnet18.onnx" with open(onnx_path, "rb") as f: parser.parse(f.read()) # 构建 TensorRT 引擎 engine = builder.build_cuda_engine(network) # 保存 TensorRT 引擎 with open("resnet18.trt", "wb") as f: f.write(engine.serialize()) ``` 3. 导入 TensorRT 模型: ```python import tensorrt as trt # 加载 TensorRT 模型 trt_path = "resnet18.trt" with open(trt_path, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.WARNING)) as runtime: engine = runtime.deserialize_cuda_engine(f.read()) ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值