关于pth转onnx以及检测onnx模型是否正确

文章介绍了如何将PyTorch模型(如MyModel)从.pth格式转换为.onnx格式,包括权重加载、模型评估和使用onnxruntime进行模型预测的过程。着重于.pth转onnx的关键步骤和注意事项。
摘要由CSDN通过智能技术生成

目录

一、pth转onnx

二、检测onnx模型


一、pth转onnx

load_state_dict:指从一个字典对象中加载神经网络的参数。

state_dict:可用于保存模型参数、超参数以及优化器的状态信息。只有可学习参数的层,如卷积层、线性层等才有state_dict。

torch.save():用来加载torch.save()保存的模型文件。

model = MyModel() # MyModel为自己用的模型

保存整个模型:torch.save(model, 'best_weight.pth')

只保存训练好的权重:torch.save(model.state_dict(),'best_weight.pth')

torch.load() :用于加载模型。

pth若只包含权重参数:

注意:若直接 (model.state_dict(), '*.pth'),则会得出一个错误的pth。应先加载出pth权重文件,再加载神经网络的参数,即 model .load_state_dict(torch.load('*.pth'))

model = MyModel()

model .load_state_dict(torch.load('*.pth'))

model.eval() # 不启用BatchNormalization和Dropout层 

import torch
from net import MyModel# 自己的模型

model_path = './best_weight.pth' # 这里的 pth 只包含了权重参数

model = MyModel()

model.load_state_dict(torch.load(model_path))

model.eval()

# 在机器学习模型开发和测试中,通常需要创建一个测试数据集用来评估模型的性能和准确性。
dummy_input = torch.randn(1, 3, 640, 640) # 虚拟输入,模拟输入数据的格式和形状。

torch.onnx.export(model, dummy_input, 'best_weight.onnx', verbose=True, input_names=['input'],
                  output_names=['output'])

print('Successful!')

二、检测onnx模型

import os, sys

sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


img = cv2.imread("./img/1.jpg")

img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_CUBIC)

to_tensor = transforms.ToTensor()
img = to_tensor(img)
img = img.unsqueeze_(0)

onnx_model_path = 'best_weight.onnx'
rnet_session = onnxruntime.InferenceSession(onnx_model_path)

# compute ONNX Runtime output prediction
inputs = {rnet_session.get_inputs()[0].name: to_numpy(img)}
outs = rnet_session.run(None, inputs) # 推理

print(outs)
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值