使用TorchScript和ONNX进行PyTorch模型的部署



前言

在需要用到深度学习项目当中,把训练好的模型部署到生产环境是其中很关键很关键的一步。PyTorch提供了多种部署方式,其中TorchScript和ONNX是最常用的方法。本文将介绍如何使用这两种方法将PyTorch模型转换为可部署格式。


提示:以下是本篇文章正文内容

一、TorchScript

TorchScript是PyTorch提供的一种工具,它允许开发者将模型进行转化,使得模型能够在非PyThon环境下运行。

转换步骤

  • 准备模型:生成模型,一般继承自.nn

  • 脚本化与追踪:

    • 脚本化:适用于包含动态控制流的模型。
    import torch
    
    class MyModel(torch.nn.Module):
    	def forward(self, x):
        	return x * 2
        	
    model = MyModel()
    scripted_model = torch.jit.script(model)
    
    • 追踪:适用于简单前向传播的模型。
    import torch
    
    model = MyModel()
    example_input = torch.tensor([1.0])
    traced_model = torch.jit.trace(model, example_input)
    
  • 保存模型:

scripted_model.save("model.pt")
  • 加载模型:
loaded_model = torch.jit.load("model.pt")

二、ONNX

ONNX(Open Neural Network Exchange)是一种开放的深度学习模型格式,旨在支持跨框架模型的互操作性。通过将PyTorch模型导出为ONNX格式,开发者可以在其他深度学习框架中使用这些模型,增加了灵活性。

2.1 导出步骤

import torch
import torch.onnx

model = MyModel()
example_input = torch.tensor([1.0])

torch.onnx.export(model, example_input, "model.onnx", 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'])
# export_params=True:表示导出模型的参数。
# opset_version=11:指定ONNX操作集版本,确保与ONNX支持的运算符兼容。
# do_constant_folding=True:进行常量折叠优化,以减少模型大小和提高推理速度。
# input_names和output_names:用于在导出时命名输入和输出

# 加载模型(假设模型类为pth)
model = models.resnet50()  # 或者自定义模型
model.load_state_dict(torch.load("model.pth"))  # 加载权重
model.eval()  # 切换到评估模式
# 创建示例输入
example_input = torch.randn(1, 3, 224, 224)  # 根据模型的输入要求调整形状

# 导出为ONNX格式
torch.onnx.export(model, example_input, "model.onnx", 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'])

2.2 常见错误:ResNet导出时的报错

如果在转换ResNet模型为ONNX时出现错误信息,如“resnet is no atribute to backbone”,通常表示以下问题:
1.模型结构问题:可能是模型中有未连接的层,或组件未在前向传播中被调用。
2.不支持的运算符:使用的操作符在ONNX中不被支持。
3.输入数据不匹配:确保输入张量的形状和数据类型与模型的要求一致。

解决步骤:
1.检查模型定义,确保所有层都正确连接。
2.使用标准的ResNet实现,或者确保自定义层是可导出的。
3.使用torch.onnx.export的verbose=True参数,获取更多调试信息。

三、使用pretrained=False

当使用pretrained=False时,表示将使用从头开始训练的模型,而不是加载已经在大规模数据集(如ImageNet)上预训练的模型。这在以下情况下非常有用:

  • 自定义任务:数据集与预训练模型的数据集差异较大时,从头开始训练模型可能更好。
  • 完全控制:从零开始训练模型可以让你选择初始权重和超参数。
  • 实验性需求:测试新的架构或训练策略,而不受预训练权重的影响。

使用步骤

如果选择不使用预训练的权重,可以这样操作:

import torch
import torchvision.models as models

# 设定预训练为False
model = models.resnet50(pretrained=False)

# 修改模型的最后一层以适应你的任务
num_classes = 10  # 根据你的数据集的类别数量
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# 接下来可以进行训练

在训练完成后,可以像之前一样导出为ONNX格式:

example_input = torch.randn(1, 3, 224, 224)  # 根据模型的输入要求

torch.onnx.export(model, example_input, "resnet_no_pretrained.onnx",
                  export_params=True,
                  opset_version=11,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'])

另一种方法可查看ncnn之resnet图像分类网络模型部署中关于ONNX模型导出的内容

总结

提示:这里对文章进行总结:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值