目录
一、pth转onnx
load_state_dict:指从一个字典对象中加载神经网络的参数。
state_dict:可用于保存模型参数、超参数以及优化器的状态信息。只有可学习参数的层,如卷积层、线性层等才有state_dict。
torch.save():用来加载torch.save()保存的模型文件。
model = MyModel() # MyModel为自己用的模型
保存整个模型:torch.save(model, 'best_weight.pth')
只保存训练好的权重:torch.save(model.state_dict(),'best_weight.pth')
torch.load() :用于加载模型。
pth若只包含权重参数:
注意:若直接 (model.state_dict(), '*.pth'),则会得出一个错误的pth。应先加载出pth权重文件,再加载神经网络的参数,即 model .load_state_dict(torch.load('*.pth'))。
model = MyModel()
model .load_state_dict(torch.load('*.pth'))
model.eval() # 不启用BatchNormalization和Dropout层
import torch
from net import MyModel# 自己的模型
model_path = './best_weight.pth' # 这里的 pth 只包含了权重参数
model = MyModel()
model.load_state_dict(torch.load(model_path))
model.eval()
# 在机器学习模型开发和测试中,通常需要创建一个测试数据集用来评估模型的性能和准确性。
dummy_input = torch.randn(1, 3, 640, 640) # 虚拟输入,模拟输入数据的格式和形状。
torch.onnx.export(model, dummy_input, 'best_weight.onnx', verbose=True, input_names=['input'],
output_names=['output'])
print('Successful!')
二、检测onnx模型
import os, sys
sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
img = cv2.imread("./img/1.jpg")
img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_CUBIC)
to_tensor = transforms.ToTensor()
img = to_tensor(img)
img = img.unsqueeze_(0)
onnx_model_path = 'best_weight.onnx'
rnet_session = onnxruntime.InferenceSession(onnx_model_path)
# compute ONNX Runtime output prediction
inputs = {rnet_session.get_inputs()[0].name: to_numpy(img)}
outs = rnet_session.run(None, inputs) # 推理
print(outs)