关于pth转onnx以及检测onnx模型是否正确

目录

一、pth转onnx

二、检测onnx模型


一、pth转onnx

load_state_dict:指从一个字典对象中加载神经网络的参数。

state_dict:可用于保存模型参数、超参数以及优化器的状态信息。只有可学习参数的层,如卷积层、线性层等才有state_dict。

torch.save():用来加载torch.save()保存的模型文件。

model = MyModel() # MyModel为自己用的模型

保存整个模型:torch.save(model, 'best_weight.pth')

只保存训练好的权重:torch.save(model.state_dict(),'best_weight.pth')

torch.load() :用于加载模型。

pth若只包含权重参数:

注意:若直接 (model.state_dict(), '*.pth'),则会得出一个错误的pth。应先加载出pth权重文件,再加载神经网络的参数,即 model .load_state_dict(torch.load('*.pth'))

model = MyModel()

model .load_state_dict(torch.load('*.pth'))

model.eval() # 不启用BatchNormalization和Dropout层 

import torch
from net import MyModel# 自己的模型

model_path = './best_weight.pth' # 这里的 pth 只包含了权重参数

model = MyModel()

model.load_state_dict(torch.load(model_path))

model.eval()

# 在机器学习模型开发和测试中,通常需要创建一个测试数据集用来评估模型的性能和准确性。
dummy_input = torch.randn(1, 3, 640, 640) # 虚拟输入,模拟输入数据的格式和形状。

torch.onnx.export(model, dummy_input, 'best_weight.onnx', verbose=True, input_names=['input'],
                  output_names=['output'])

print('Successful!')

二、检测onnx模型

import os, sys

sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


img = cv2.imread("./img/1.jpg")

img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_CUBIC)

to_tensor = transforms.ToTensor()
img = to_tensor(img)
img = img.unsqueeze_(0)

onnx_model_path = 'best_weight.onnx'
rnet_session = onnxruntime.InferenceSession(onnx_model_path)

# compute ONNX Runtime output prediction
inputs = {rnet_session.get_inputs()[0].name: to_numpy(img)}
outs = rnet_session.run(None, inputs) # 推理

print(outs)
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
要验证 PyTorch 模型ONNX 导出是否正确,您可以执行以下步骤: 1. 首先,使用 PyTorch 将模型导出为 ONNX 格式。这可以通过使用 `torch.onnx.export()` 函数来完成。例如: ```python import torch # 定义模型 class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.linear = torch.nn.Linear(10, 1) def forward(self, x): return self.linear(x) # 创建示例输入 input_data = torch.randn(1, 10) # 将模型导出为 ONNX 格式 torch.onnx.export(MyModel(), input_data, "my_model.onnx") ``` 2. 安装 ONNX 运行时。可以通过以下命令来安装: ``` pip install onnxruntime ``` 3. 加载 ONNX 模型并执行推理。可以使用 `onnxruntime.InferenceSession()` 函数加载 ONNX 模型,并使用 `session.run()` 函数执行推理。例如: ```python import onnxruntime # 加载 ONNX 模型 session = onnxruntime.InferenceSession("my_model.onnx") # 创建输入数据 input_data = { session.get_inputs()[0].name: input_data.numpy() } # 执行推理 output_data = session.run(None, input_data) # 输出结果 print(output_data) ``` 4. 验证输出结果是否正确。最后,您需要验证 ONNX 导出的模型的输出是否与 PyTorch 模型的输出相同。您可以使用 PyTorch 运行模型并计算输出,并将其与使用 ONNX 运行模型时得到的输出进行比较。例如: ```python # 使用 PyTorch 运行模型 model = MyModel() output_data_torch = model(input_data).detach().numpy() # 验证输出是否相同 assert np.allclose(output_data[0], output_data_torch) ``` 如果断言没有触发异常,那么说明 ONNX 导出的模型的输出与 PyTorch 模型的输出相同,验证成功。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值