ONNX-Scala 项目使用教程
1. 项目介绍
ONNX-Scala 是一个用于 Scala 3 的 ONNX(Open Neural Network eXchange)API 和后端,支持类型化的函数式深度学习和经典机器学习。该项目旨在提供一个高性能的 ONNX 模型推理接口,并支持跨平台(JVM、JavaScript、Scala Native)的开发。
ONNX-Scala 的主要特点包括:
- 类型化 API:提供类型安全的 API,确保在编译时捕获错误。
- 跨平台支持:支持 Scala JVM、Scala.js/JavaScript 和 Scala Native。
- 高性能:基于 ONNX Runtime,提供高效的模型推理性能。
- 全面支持 ONNX 操作符:支持大部分 ONNX 操作符,适用于各种深度学习和机器学习任务。
2. 项目快速启动
2.1 环境准备
首先,确保你已经安装了以下工具:
- Scala 3
- sbt(Scala 构建工具)
2.2 添加依赖
在你的项目 build.sbt
文件中添加以下依赖:
libraryDependencies += "org.emergent-order" %% "onnx-scala-backends" % "0.17.0"
2.3 下载模型文件
下载 SqueezeNet 模型文件,可以使用项目提供的脚本:
./get_models.sh
2.4 创建图像张量
以下代码展示了如何创建一个由像素值 42 组成的图像张量:
import java.nio.file.{Files, Paths}
import org.emergentorder.onnx.Tensors._
import org.emergentorder.onnx.Tensors.Tensor
import org.emergentorder.onnx.backends._
import org.emergentorder.compiletime._
import org.emergentorder.io.kjaer.compiletime._
val squeezenetBytes = Files.readAllBytes(Paths.get("squeezenet1_1_Opset18.onnx"))
val squeezenet = new ORTModelBackend(squeezenetBytes)
val data = Array.fill(1*3*224*224)(42f) // NCHW 格式的图像张量
val shape = 1 #: 3 #: 224 #: 224 #: SNil
val tensorShapeDenotation = "Batch" ##: "Channel" ##: "Height" ##: "Width" ##: TSNil
val tensorDenotation: String & Singleton = "Image"
val imageTens = Tensor(data, tensorDenotation, tensorShapeDenotation, shape)
// 或者使用简写形式,如果你不关心符号表示
val imageTensDefaultDenotations = Tensor(data, shape)
2.5 运行模型推理
接下来,运行 SqueezeNet 图像分类推理:
val out = squeezenet.fullModel[Float, "ImageNetClassification", "Batch" ##: "Class" ##: TSNil, 1 #: 1000 #: SNil](Tuple(imageTens))
// 输出形状
val outShape = out.shape.unsafeRunSync()
println(outShape) // 输出: Array(1, 1000, 1, 1)
// 输出数据
val outData = out.data.unsafeRunSync()
println(outData.indices.maxBy(outData)) // 输出最高概率的类别索引
3. 应用案例和最佳实践
3.1 图像分类
ONNX-Scala 可以用于图像分类任务,如使用 SqueezeNet 模型对图像进行分类。通过类型化的 API,可以确保输入和输出的形状和类型在编译时得到验证,减少运行时错误。
3.2 模型优化
在性能关键场景中,ONNX-Scala 提供了高效的模型推理接口。通过直接加载 ONNX 模型文件并传递给底层 ONNX 后端,可以充分利用 ONNX Runtime 的优化能力,实现高性能的模型推理。
3.3 跨平台开发
ONNX-Scala 支持跨平台开发,适用于 JVM、JavaScript 和 Scala Native。开发者可以根据项目需求选择合适的平台,实现代码的跨平台复用。
4. 典型生态项目
4.1 ONNX Runtime
ONNX Runtime 是一个跨平台的、高性能的机器学习推理引擎,支持多种硬件加速器。ONNX-Scala 基于 ONNX Runtime 构建,提供了高效的模型推理能力。
4.2 ScalaPB
ScalaPB 是一个用于生成 Scala 代码的 Protocol Buffers 编译器插件。ONNX-Scala 使用 ScalaPB 生成的代码来访问 ONNX 模型的 Protobuf 定义,从而实现类型安全的模型操作。
4.3 Spire
Spire 是一个强大的 Scala 数值计算库,提供了丰富的数值类型和抽象。ONNX-Scala 使用 Spire 来支持无符号整数、复数等数值类型,增强了数值计算的能力。
4.4 Dotty
Dotty 是 Scala 3 的编译器,引入了许多新特性,如联合类型、匹配类型和编译时单例操作。ONNX-Scala 充分利用这些特性,实现了类型安全的 API 设计。