主要通过标准化预测模型编辑语言PMML进行格式化保存,以XML描述的文件保存在磁盘,实现数据挖掘模型的可移植性。
本文以如下简单的逻辑回归模型训练的例子进行相关说明,训练模型python代码如下:
from sklearn.linear_model import LogisticRegression
from sklearn import datasets
clf = LogisticRegression() #逻辑回归算法
iris = datasets.load_iris() #引入sklearn内置测试数据
X, y = iris.data, iris.target #构造特征列与目标列
clf.fit(X, y) #模型训练
1.常用数据挖掘模型保存及加载预测的方法
(1)使用python内置的pickle 库以pkl文件保存模型,代码示例如下:
import pickle
with open("model.pkl", "wb") as f:
pickle.dump(clf, f) #以二进制可写模式打开
clf2 = pickle.loads(model.pkl ) #加载模型
print(clf2.predict(X[0:1])) #打印模型预测结果
(2)使用sklearn内置的joblib库以pkl文件保存模型,代码示例如下:
from sklearn.externals import joblib
joblib.dump(clf, "model/model.pkl") #保存在model目录下
clf3 = joblib.load("model/model.pkl") #加载模型
print(clf3.predict(X[0:1])) #打印模型预测结果
(3)使用sklearn2pmml插件以pmml文件保存模型,代码示例如下:
from sklearn2pmml import PMMLPipeline, sklearn2pmml
pipeline = PMMLPipeline([("classifier", clf)]) #创建模型管道
pipeline.fit(X, y) #训练模型
sklearn2pmml(pipeline, "model/model.pmml") #导出模型
pipeline.load("model/model.pmml") #加载模型
print(pipeline.predict("model/model.pmml")) #打印模型预测结果
2.模型跨平台使用
为支持模型跨平台使用,需要将模型以PMML文件导出保存,然后供其它平台调用。这里以java调用模型使用为例进行介绍:
标准化预测模型编辑语言PMML提供了相关插件支持模型的跨平台调用,相关插件可在Maven公共仓库或GitHub搜索“jpmml”获取,java调用模型主要使用jpmml-pmml-evaluator与jpmml-pmml-evaluator-extension两个插件包,均可通过maven引入项目,最新版本插件调用模型pmml文件获取模型信息的主要部分代码示例如下:
File file = new File("model.pmml")
Evaluator evaluator = new LoadingModelEvaluatorBuilder().load(file).build();
evaluator.verify(); //模型自检
List<JSONObject> inputs = new ArrayList<>(); //模型输入列(特征)详情
for (InputField inputField : evaluator.getInputFields()) {
JSONObject input = new JSONObject();
input.put("dataType", inputField.getDataType().value().toUpperCase());
input.put("name", inputField.getName().getValue());
inputs.add(input);
}
List<JSONObject> outputs = new ArrayList<>(); //模型输出列详情
for (OutputField outputField : evaluator.getOutputFields()) {
JSONObject ouput = new JSONObject();
ouput.put("dataType", outputField.getDataType().value().toUpperCase());
ouput.put("name", outputField.getName().getValue());
ouput.put("value",outputField.getField().getValue().toString());
outputs.add(ouput);
}
List<JSONObject> targets = new ArrayList<>(); //模型目标列详情
for (TargetField targetField : evaluator.getTargetFields()) {
JSONObject target = new JSONObject();
target.put("dataType", targetField.getDataType().value().toUpperCase());
target.put("name", targetField.getName().getValue());
targets.add(target);
}