ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植

目录

PyTorch简介

导入转换器

快速浏览模型

将PyTorch模型转换为ONNX

摘要和后续步骤

参考文献


系列文章列表如下:

ONNX系列一 --- 带有ONNX的便携式神经网络

ONNX系列二 --- 使用ONNX使Keras模型可移植

ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植

ONNX系列四 --- 使用ONNX使TensorFlow模型可移植

ONNX系列五 --- 在C#中使用可移植的ONNX AI模型

ONNX系列六 --- 在Java中使用可移植的ONNX AI模型

ONNX系列七 --- 在Python中使用可移植的ONNX AI模型

在关于2020年使用便携式神经网络的系列文章中,您将学习如何将PyTorch模型转换为便携式ONNX格式。

由于ONNX并不是用于构建和训练模型的框架,因此我将首先简要介绍PyTorch。对于从头开始并考虑将PyTorch作为构建和训练其模型的框架的工程师而言,这将很有用。

PyTorch简介

PyTorch2016年发布,由FacebookAI研究实验室(FAIR)开发。它已成为研究人员尝试自然语言处理和计算机视觉的首选框架。这很有趣,因为TensorFlow在生产环境中得到了更广泛的使用。研究实验室和生产之间的这种二分法确实强调了像ONNX这样的标准的价值,该标准为模型提供了通用格式,并提供了可用于所有流行编程语言的运行时。举个例子,假设一个组织不想在其生产环境中拥有所有可能的框架,而是希望对其进行标准化。如果没有ONNX,则需要在选择用于生产和部署的框架中重新实现该模型。这是一项艰巨的工程任务。使用ONNXPyTorch模型可以仅用几行代码导出,并可以从任何语言中使用。生产中仅需要ONNX Runtime

导入转换器

PyTorch的维护人员已将ONNX转换器集成到PyTorch本身中。您不需要安装任何其他软件包。安装PyTorch之后,您可以通过在模块中包含以下导入内容来访问PyTorchONNX转换器:

import torch

导入torch模块后,您可以按以下方式访问转换函数:

torch.onnx.export()

希望这是其他框架将采用的一种做法。使用框架本身对转换器进行打包和版本控制可以减少安装的软件包数量,并且还可以防止框架和转换器之间的版本不匹配。

快速浏览模型

在转换PyTorch模型之前,我们需要查看创建模型的代码,以确定输入的形状。下面的代码创建一个PyTorch模型,该模型预测MNIST数据集中发现的数字。对模型层的详细描述超出了本文的范围,但是我们确实需要注意输入的形状。在这里是784。更具体地说,这段代码正在创建一个模型,其中输入将是一个平坦的张量,该张量是784个浮点的数组。784的意义是什么?好吧,MNIST数据集中的每个图像都是28×28像素的图像。28×28 =784。因此,一旦展平,我们的输入为784个浮点数,其中每个浮点数都代表灰色阴影。底线:该模型期望单个图像中有784个浮点数。它不期望多维数组,也不期望一批图像。一次只有一个预测。将模型转换为ONNX时,这是一个重要的事实。

def build_model():
   # Layer details for the neural network
   input_size = 784
   hidden_sizes = [128, 64]
   output_size = 10

   # Build a feed-forward network
   model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                       nn.ReLU(),
                       nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                       nn.ReLU(),
                       nn.Linear(hidden_sizes[1], output_size),
                       nn.LogSoftmax(dim=1))
   return model

PyTorch模型转换为ONNX

以下函数说明了如何使用torch.onnx.export函数。正确使用此函数有一些技巧。第一个也是最重要的技巧是正确设置样本输入。sample_input参数用于确定ONNX模型的输入。export_to_onnx函数将接受您提供的任何内容(只要它是张量即可),并且转换将正确无误地进行。但是,如果示例输入的形状错误,则尝试从ONNX运行时运行ONNX模型时会收到错误消息。

def export_to_onnx(model):
   sample_input = torch.randn(1, 784)
   print(type(sample_input))
   torch.onnx.export(model,            # model being run
                     sample_input,     # model input (or a tuple for multiple inputs)
                     ONNX_MODEL_FILE,  # where to save the model
                     input_names = ['input'],   # the model's input names
                     output_names = ['output'] # the model's output names

   # Set metadata on the model.
   onnx_model = onnx.load(ONNX_MODEL_FILE)
   meta = onnx_model.metadata_props.add()
   meta.key = "creation_date"
   meta.value = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
   meta = onnx_model.metadata_props.add()
   meta.key = "author"
   meta.value = 'keithpij'
   onnx_model.doc_string = 'MNIST model converted from Pytorch'
   onnx_model.model_version = 3  # This must be an integer or long.
   onnx.save(onnx_model, ONNX_MODEL_FILE)

如果原始的PyTorch模型设计为可以接收100张图像,那么此示例输入就可以了。但是,如前所述,我们的模型被设计为在进行预测时一次只能接受一张图像。如果使用此样本输入导出模型,则在运行模型时会出现错误。

将元数据添加到模型的代码是最佳实践。随着用于训练模型的数据的发展,模型也将随之发展。因此,最好将元数据添加到模型中,以便您可以将其与以前的模型区分开。上面的示例将模型的简短描述添加到doc_string属性中并设置版本。creation_dateauthor是添加到metadata_props属性包中的自定义属性。您可以使用此属性包自由创建尽可能多的自定义属性。不幸的是,model_version属性需要一个整数或长整数,因此您将无法使用major.minor.revision语法像服务一样对其进行版本控制。此外,导出功能会自动将模型保存到文件中,因此要添加此元数据,您需要重新打开文件并重新保存它。

摘要和后续步骤

在本文中,我为那些寻求用于构建和训练神经网络的深度学习框架的人提供了PyTorch的简要概述。然后,我展示了如何使用转换工具将PyTorch模型转换为ONNX格式,该工具已经是PyTorch本身的一部分。我还展示了将元数据添加到导出模型的最佳实践。

由于本文的目的是演示将Keras模型转换为ONNX格式,因此我没有详细介绍如何构建和培训Keras模型。本文的代码示例包含探索Keras本身的代码。所述keras_mnist.py模块是一个完整的端至端演示的是示出了如何加载数据、探索图像和训练模型。

接下来,我们将研究将TensorFlow模型转换为ONNX

参考文献

  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值