ONNX系列六 --- 在Java中使用可移植的ONNX AI模型

目录

安装和导入ONNX运行时

载入ONNX模型

使用ONNX运行时进行预测

摘要和后续步骤

参考文献


系列文章列表如下:

ONNX系列一 --- 带有ONNX的便携式神经网络

ONNX系列二 --- 使用ONNX使Keras模型可移植

ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植

ONNX系列四 --- 使用ONNX使TensorFlow模型可移植

ONNX系列五 --- 在C#中使用可移植的ONNX AI模型

ONNX系列六 --- 在Java中使用可移植的ONNX AI模型

ONNX系列七 --- 在Python中使用可移植的ONNX AI模型

在关于2020年使用便携式神经网络的系列文章中,您将学习如何在x64架构上安装ONNX并在Java中使用它。

微软与FacebookAWS共同开发了ONNXONNX格式和ONNX运行时都具有行业支持,以确保所有重要框架都能够将其图形导出到ONNX,并且这些模型可以在任何硬件配置上运行。

ONNX Runtime是用于运行已转换为ONNX格式的机器学习模型的引擎。传统机器学习模型和深度学习模型(神经网络)都可以导出为ONNX格式。运行时可以在LinuxWindowsMac上运行,并且可以在各种芯片体系结构上运行。它还可以利用诸如GPUTPU之类的硬件加速器。但是,没有针对操作系统,芯片体系结构和加速器的每种组合的安装包,因此,如果不使用任何一种常见组合,则可能需要从源代码构建运行时。检查ONNX运行时网站以获得所需组合的安装说明。本文将展示如何在具有默认CPUx64体系结构和具有GPUx64体系结构上安装ONNX Runtime

除了可以在许多硬件配置上运行之外,还可以从大多数流行的编程语言中调用运行时。本文的目的是展示如何在Java中使用ONNX Runtime。我将展示如何安装onnxruntime软件包。安装ONNX Runtime后,我会将先前导出的MNIST模型加载到ONNX Runtime中,并使用它进行预测。

安装和导入ONNX运行时

在使用ONNX运行时之前,您需要向构建工具中添加适当的依赖项。Maven资源库是为各种工具(包括MavenGradle)设置ONNX运行时的良好来源。要在具有默认CPUx64架构上使用运行时,请参考下面的链接。

https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform

要在带GPUx64架构上使用运行时,请使用以下链接。

https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform-gpu

一旦安装了运行时,就可以使用如下所示的import语句将其导入到Java代码文件中。引入TensorProto工具的import语句将帮助我们为ONNX模型创建输入,也将有助于解释ONNX模型的输出(预测)。

import ai.onnxruntime.OnnxMl.TensorProto;
import ai.onnxruntime.OnnxMl.TensorProto.DataType;
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode;
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;

载入ONNX模型

以下代码段显示了如何将ONNX模型加载到以Java运行的ONNX Runtime中。此代码创建可用于进行预测的会话对象。这里使用的模型是从PyTorch导出的ONNX模型。

这里有几件事值得注意。首先,您需要查询会话以获取其输入。这是使用会话的getInputInfo方法完成的。我们的MNIST模型只有一个输入参数:784个浮点数组,代表MNIST数据集中的一张图像。如果您的模型有多个输入参数,则InputMetadata的每个参数都有一个条目。

Utilities.LoadTensorData();
String modelPath = "pytorch_mnist.onnx";

try (OrtSession session = env.createSession(modelPath, options)) {
   Map<String, NodeInfo> inputMetaMap = session.getInputInfo();
   Map<String, OnnxTensor> container = new HashMap<>();
   NodeInfo inputMeta = inputMetaMap.values().iterator().next();

   float[] inputData = Utilities.ImageData[imageIndex];
   string label = Utilities.ImageLabels[imageIndex];
   System.out.println("Selected image is the number: " + label);

   // this is the data for only one input tensor for this model
   Object tensorData =
            OrtUtil.reshape(inputData, ((TensorInfo) inputMeta.getInfo()).getShape());
   OnnxTensor inputTensor = OnnxTensor.createTensor(env, tensorData);
   container.put(inputMeta.getName(), inputTensor);

   // Run code omitted for brevity.

}

 

上面的代码未显示用于读取原始MNIST图像并将每个图像转换为784个浮点数组的实用程序。还可以从MNIST数据集中读取每个图像的标签,以便可以确定预测的准确性。该代码是标准Java代码,但仍然鼓励您检出并使用它。如果您需要读入与MNIST数据集相似的图像,它将节省您的时间。

使用ONNX运行时进行预测

以下函数说明了如何使用在加载ONNX模型时创建的ONNX会话。

try (OrtSession session = env.createSession(modelPath, options)) {

   // Load code not shown for brevity.

   // Run the inference
   try (OrtSession.Result results = session.run(container)) {

      // Only iterates once
      for (Map.Entry<String, OnnxValue> r : results) {
         OnnxValue resultValue = r.getValue();
         OnnxTensor resultTensor = (OnnxTensor) resultValue;
         resultTensor.getValue()
         System.out.println("Output Name: {0}", r.Name);
         int prediction = MaxProbability(resultTensor);
         System.out.println("Prediction: " + prediction.ToString());
	}
   }
}

大多数神经网络不会直接返回预测。它们返回每个输出类的概率列表。对于我们的MNIST模型,每个图像的返回值将是10个概率的列表。可能性最高的条目是预测。您可以做一个有趣的测试,将ONNX模型在创建模型的框架中运行时返回的概率与从原始模型返回的概率进行比较。理想情况下,模型格式和运行时的更改不应更改所产生的任何概率。这将使每当模型发生更改时都可以运行良好的单元测试。

摘要和后续步骤

在本文中,我简要介绍了ONNX运行时和ONNX格式。然后,我展示了如何在ONNX运行时中使用Java加载和运行ONNX模型。

本文的代码示例包含一个工作的Console应用程序,该应用程序演示了此处显示的所有技术。此代码示例是Github存储库的一部分,该存储库探讨了使用神经网络预测MNIST数据集中发现的数字的方法。具体来说,有一些示例显示了如何在KerasPyTorchTensorFlow 1.0TensorFlow 2.0中创建神经网络。

如果您想了解有关导出为ONNX格式和使用ONNX Runtime的更多信息,请查阅本系列的其他文章。

参考文献

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值