Pytorch是一款开源的深度学习框架,它可以帮助开发者在深度学习领域快速搭建和部署模型。基于Pytorch,开发者可以更容易地训练和部署深度学习模型,从而获得更高的效率。本文将主要介绍Pytorch模型训练后导出至onnx,并介绍onnx的概念及其在Pytorch中的应用。
ONNX(Open Neural Network Exchange)是一种开放的深度学习模型格式,开发者可以使用它将模型从一种框架转换到另一种框架。它支持从Pytorch到其他框架的模型转换,如Caffe2,MXNet等。使用ONNX,开发者可以更容易地将模型迁移到不同的框架,以实现更高的效率。
Pytorch支持将训练完成的模型导出到ONNX格式,以便将模型迁移到其他框架中。可以使用torch.onnx.export()函数将模型导出到ONNX格式,该函数接受一个Pytorch模型,一个输入张量列表和一个输出张量列表作为参数。
ONNX的使用可以帮助开发者更快地实现模型的迁移,从而大大提高开发者的效率。本文将主要介绍Pytorch模型训练后导出至onnx的过程,并介绍onnx的概念及其在Pytorch中的应用。通过本文,开发者可以更容易地理解onnx的概念,并学会如何将Pytorch模型导出至onnx。
from torchvision.transforms import Resize
import torch
from torchvision import transforms, datasets
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
resnet18 = models.resnet18(pretrained=False)
# 获取resnet18最后一层输出,输出为512维,最后一层本来是用作 分类的,原始网络分为1000类
# 用 softmax函数或者 fully connected 函数,但是用 nn.identtiy() 函数把最后一层替换掉,相当于得到分类之前的特征!
#Identity模块,它将输入直接传递给输出,而不会对输入进行任何变换。
resnet18.fc = nn.Identity()
# 构建新的网络,将resnet18的输出作为输入
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
#全卷积神经网络,不用调整输入大小
self.resnet18 = resnet18
self.fc1 = nn.Linear(512, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 64)
self.fc4 = nn.Linear(64, 10)
self.fc5 = nn.Linear(10, 2)
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
x = self.resnet18(x)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = F.relu(self.fc4(x))
x = F.relu(self.fc5(x))
x = self.softmax(x)
x=x.view(-1,2)
return x
# 加载模型,注意下面的类适用于:保存为.pth模型时,只保存了权重的情况
class OnnxWorker():
def __init__(self,model,path,onnx_save_path="model.onnx"):
self.model = model
self.load_state_dict(path)
self.onnx_save_path=onnx_save_path
self.export_onnx()
self.simplify_onnx_model()
def load_state_dict(self, path):
self.model.load_state_dict(torch.load(path))
def print_model(self):
print(self.model)
def export_onnx(self):
dynamic_axes = {
'input': {0: 'batch_size'}, # batch_size为动态
'output': {0: 'batch_size'},
}
torch.onnx.export(self.model, torch.randn(1, 3, 112, 112), self.onnx_save_path,
export_params=True,
do_constant_folding=True,
opset_version=11,
input_names=['input'],
output_names=['output'],
verbose=False,
dynamic_axes=dynamic_axes)
print("export_onnx")
def simplify_onnx_model(self):
import onnx
from onnxsim import simplify
onnx_model = onnx.load(self.onnx_save_path) # load onnx model
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, self.onnx_save_path) # save the simplyfied onnx model
print("simplify_onnx_model")
#调用处:
worker=OnnxWorker(Net(),r"C:\Users\25360\Desktop\98-FaceLandmarks-main\bestmodel98.559.pth")