[Python] 如何把scikit-learn的线性回归模型导出为onnx格式,并使用onnx模型文件进行预测

本文介绍了Scikit-learn,一个流行的Python机器学习库,以及ONNX如何促进不同深度学习框架之间的模型互操作。此外,还讲解了如何使用skl2onnx将Scikit-learn模型转换为ONNX格式,以便于跨框架部署和预测。
摘要由CSDN通过智能技术生成

什么是Scikit-learn?

Scikit-learn是一个用于Python编程语言的机器学习库。它提供了各种监督和无监督学习算法,包括分类、回归、聚类、降维等。Scikit-learn易于使用且功能强大,可以处理大型数据集,并且具有很好的可扩展性。它还提供了许多方便的工具,如数据预处理、模型选择、评估和可视化等。Scikit-learn是许多机器学习项目中使用的首选库之一。

什么是ONNX?

ONNX(Open Neural Network Exchange)是一个开放的生态系统,旨在使不同的深度学习框架之间能够互操作。它定义了一个通用的模型表示格式,使得在不同的深度学习框架之间进行模型转换和部署变得更加容易。ONNX模型可以被多种工具和平台所支持,包括ONNX Runtime、TensorFlow、PyTorch、Caffe2等。通过使用ONNX,开发者可以轻松地将一个深度学习模型转换为另一个框架所需的格式,从而实现模型的重用和加速。

什么是ONNX Runtime?

ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exchange (ONNX) models. For more information on ONNX Runtime, please see aka.ms/onnxruntime or the Github project.

ONNX Runtime | Home

GitHub - microsoft/onnxruntime: ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator

什么是skl2onnx库?

skl2onnx · PyPI

sklearn-onnx 1.16.0 documentation 

 

安装onnx onnxruntime skl2onnx库

pip install onnx onnxruntime skl2onnx

scikit-learn的线性回归模型导出为onnx格式

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from skl2onnx import to_onnx

# 创建数据集
np.random.seed(0)
x = np.random.rand(100, 1)
x = x.astype(np.float32)
y = 2 + 3 * x + np.random.rand(100, 1)
y = y.astype(np.float32)

# 将数据集分为训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)

# 训练模型
model = LinearRegression()
model.fit(x_train, y_train)

# 保存模型到文件
onx = to_onnx(model, x[:1])
with open("linear_regression_model.onnx", "wb") as f:
    f.write(onx.SerializeToString())

 使用onnxruntime来加载onnx模型进行预测

import onnxruntime as rt
from sklearn.metrics import mean_squared_error, r2_score

# 从文件中加载模型
sess = rt.InferenceSession("linear_regression_model.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
print('input_name:', input_name)
print('label_name:', label_name)
# 使用加载的模型进行预测
# # 当前模型只有一个输入和一个输出,所以我们只需要通过[0]获取第一个输出,即为预测值
y_pred = sess.run([label_name], {input_name: x_test.astype(np.float32)})[0]  
print(y_pred)

# 评估模型的性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print('均方误差:', mse)
print('R2分数:', r2)

输出结果: 

  • 36
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

老狼IT工作室

你的鼓励将是我创作的最大动力。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值