ONNX Runtime库学习之InferenceSession.run函数
一、简介
InferenceSession.run
是 ONNX Runtime 库中的一个函数,用于执行已加载的 ONNX 模型的推理。该函数接收输入数据,通过模型进行计算,并返回输出结果。它是进行模型推理的核心接口,支持多种输入输出方式,包括直接传递 NumPy 数组、使用 OrtValue
对象等。
二、语法和参数
语法:
InferenceSession.run(output_names, input_feed, run_options=None)
参数:
output_names
: 一个包含期望输出的节点名称的列表。input_feed
: 一个字典,键为输入节点名称,值为对应的输入数据。run_options
: 可选的RunOptions
对象,用于控制推理的行为。
返回值:
返回一个列表,包含了指定输出节点的推理结果。每个结果都是一个 NumPy 数组。
三、实例
3.1 基本推理
- 代码:
import onnxruntime as ort
import numpy as np
# 创建 InferenceSession 实例
session = ort.InferenceSession("model.onnx")
# 准备输入数据
input_name = session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 执行推理
outputs = session.run(None, {input_name: input_data})
- 输出:
[array([[...], [...], [...]], dtype=float32)]
3.2 使用 RunOptions 控制推理
- 代码:
import onnxruntime as ort
import numpy as np
# 创建 InferenceSession 实例
session = ort.InferenceSession("model.onnx")
# 准备输入数据
input_name = session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 创建 RunOptions 实例
run_options = ort.RunOptions()
run_options.use_deterministic_compute = True
# 执行推理
outputs = session.run(None, {input_name: input_data}, run_options=run_options)
- 输出:
[array([[...], [...], [...]], dtype=float32)]
四、注意事项
- 输入数据的维度和类型必须与模型的输入要求相匹配。
output_names
参数可以是 None,这表示返回所有输出节点的结果。run_options
是可选的,但可以用来控制推理过程中的某些行为,例如是否使用确定性计算。- 如果模型有多个输出,返回的列表将包含每个输出节点的结果。
- 在处理大型模型或数据时,应注意内存使用情况,以避免内存不足的问题。