保存和跨平台使用sklearn机器学习库训练的模型的方法

       主要通过标准化预测模型编辑语言PMML进行格式化保存,以XML描述的文件保存在磁盘,实现数据挖掘模型的可移植性。

本文以如下简单的逻辑回归模型训练的例子进行相关说明,训练模型python代码如下:

from sklearn.linear_model import LogisticRegression

from sklearn import datasets

clf = LogisticRegression()   #逻辑回归算法

iris = datasets.load_iris()    #引入sklearn内置测试数据

X, y = iris.data, iris.target   #构造特征列与目标列

clf.fit(X, y)                           #模型训练

1.常用数据挖掘模型保存及加载预测的方法

(1)使用python内置的pickle 库以pkl文件保存模型,代码示例如下:

 import pickle 

   with open("model.pkl", "wb") as f:

      pickle.dump(clf, f)      #以二进制可写模式打开

   clf2 = pickle.loads(model.pkl )    #加载模型

   print(clf2.predict(X[0:1]))      #打印模型预测结果

(2)使用sklearn内置的joblib库以pkl文件保存模型,代码示例如下:

from sklearn.externals import joblib  

   joblib.dump(clf, "model/model.pkl")    #保存在model目录下

   clf3 = joblib.load("model/model.pkl")    #加载模型

   print(clf3.predict(X[0:1]))        #打印模型预测结果

(3)使用sklearn2pmml插件以pmml文件保存模型,代码示例如下:

from sklearn2pmml import PMMLPipeline, sklearn2pmml  

   pipeline = PMMLPipeline([("classifier", clf)])     #创建模型管道

   pipeline.fit(X, y)                   #训练模型

   sklearn2pmml(pipeline, "model/model.pmml")     #导出模型

   pipeline.load("model/model.pmml")                   #加载模型

   print(pipeline.predict("model/model.pmml"))    #打印模型预测结果

2.模型跨平台使用

为支持模型跨平台使用,需要将模型以PMML文件导出保存,然后供其它平台调用。这里以java调用模型使用为例进行介绍:

标准化预测模型编辑语言PMML提供了相关插件支持模型的跨平台调用,相关插件可在Maven公共仓库或GitHub搜索“jpmml”获取,java调用模型主要使用jpmml-pmml-evaluator与jpmml-pmml-evaluator-extension两个插件包,均可通过maven引入项目,最新版本插件调用模型pmml文件获取模型信息的主要部分代码示例如下:

File file = new File("model.pmml")

Evaluator evaluator = new LoadingModelEvaluatorBuilder().load(file).build();

evaluator.verify();    //模型自检

List<JSONObject> inputs = new ArrayList<>();      //模型输入列(特征)详情
for (InputField inputField : evaluator.getInputFields()) {
  JSONObject input = new JSONObject();
  input.put("dataType", inputField.getDataType().value().toUpperCase());
  input.put("name", inputField.getName().getValue());
  inputs.add(input);
}
List<JSONObject> outputs = new ArrayList<>();    //模型输出列详情
for (OutputField outputField : evaluator.getOutputFields()) {
  JSONObject ouput = new JSONObject();
  ouput.put("dataType", outputField.getDataType().value().toUpperCase());
  ouput.put("name", outputField.getName().getValue());
  ouput.put("value",outputField.getField().getValue().toString());
  outputs.add(ouput);
}
List<JSONObject> targets = new ArrayList<>();    //模型目标列详情
for (TargetField targetField : evaluator.getTargetFields()) {
  JSONObject target = new JSONObject();
  target.put("dataType", targetField.getDataType().value().toUpperCase());
  target.put("name", targetField.getName().getValue());
  targets.add(target);
}

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值