【pytorth】模型训练后导出至onnx

66 篇文章 1 订阅
65 篇文章 3 订阅

Pytorch是一款开源的深度学习框架,它可以帮助开发者在深度学习领域快速搭建和部署模型。基于Pytorch,开发者可以更容易地训练和部署深度学习模型,从而获得更高的效率。本文将主要介绍Pytorch模型训练后导出至onnx,并介绍onnx的概念及其在Pytorch中的应用。

ONNX(Open Neural Network Exchange)是一种开放的深度学习模型格式,开发者可以使用它将模型从一种框架转换到另一种框架。它支持从Pytorch到其他框架的模型转换,如Caffe2,MXNet等。使用ONNX,开发者可以更容易地将模型迁移到不同的框架,以实现更高的效率。

Pytorch支持将训练完成的模型导出到ONNX格式,以便将模型迁移到其他框架中。可以使用torch.onnx.export()函数将模型导出到ONNX格式,该函数接受一个Pytorch模型,一个输入张量列表和一个输出张量列表作为参数。

ONNX的使用可以帮助开发者更快地实现模型的迁移,从而大大提高开发者的效率。本文将主要介绍Pytorch模型训练后导出至onnx的过程,并介绍onnx的概念及其在Pytorch中的应用。通过本文,开发者可以更容易地理解onnx的概念,并学会如何将Pytorch模型导出至onnx。

from torchvision.transforms import Resize
import torch
from torchvision import transforms, datasets
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

resnet18 = models.resnet18(pretrained=False)

# 获取resnet18最后一层输出,输出为512,最后一层本来是用作 分类的,原始网络分为1000类
# 用 softmax函数或者 fully connected 函数,但是用 nn.identtiy() 函数把最后一层替换掉,相当于得到分类之前的特征!
#Identity模块,它将输入直接传递给输出,而不会对输入进行任何变换。
resnet18.fc = nn.Identity()

# 构建新的网络,将resnet18的输出作为输入
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #全卷积神经网络,不用调整输入大小
        self.resnet18 = resnet18
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)
        self.fc5 = nn.Linear(10, 2)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        x = self.resnet18(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = self.softmax(x)
        x=x.view(-1,2)
        return x
# 加载模型,注意下面的类适用于:保存为.pth模型时,只保存了权重的情况
class OnnxWorker():
    def __init__(self,model,path,onnx_save_path="model.onnx"):
        self.model = model
        self.load_state_dict(path)
        self.onnx_save_path=onnx_save_path
        self.export_onnx()
        self.simplify_onnx_model()
    def load_state_dict(self, path):
        self.model.load_state_dict(torch.load(path))

    def print_model(self):
        print(self.model)

    def export_onnx(self):
        dynamic_axes = {
            'input': {0: 'batch_size'},  # batch_size为动态
            'output': {0: 'batch_size'},
        }
        torch.onnx.export(self.model, torch.randn(1, 3, 112, 112), self.onnx_save_path,
                          export_params=True,
                          do_constant_folding=True,
                          opset_version=11,
                          input_names=['input'],
                          output_names=['output'],
                          verbose=False,
                          dynamic_axes=dynamic_axes)
        print("export_onnx")

    def simplify_onnx_model(self):
        import onnx
        from onnxsim import simplify
        onnx_model = onnx.load(self.onnx_save_path)  # load onnx model
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, self.onnx_save_path)  # save the simplyfied onnx model
        print("simplify_onnx_model")

#调用处:
worker=OnnxWorker(Net(),r"C:\Users\25360\Desktop\98-FaceLandmarks-main\bestmodel98.559.pth")
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

颢师傅

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值