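Background
In real-world projects we usually train and iterate on models in Python, because it offers powerful algorithm libraries and very convenient data-processing tools that make fast experimentation possible. When a model has to be deployed as a service, however, Java and its frameworks have clear advantages. To combine the strengths of Python and Java, PMML can act as the intermediary: the model is exported from Python as a .pmml file, which is then parsed and deployed in Java.
What is PMML?
PMML (Predictive Model Markup Language) can be thought of as an XML-based file format for exporting machine learning models. The file content is a description of the model's rules; for example, here is a decision tree trained on the Iris dataset: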
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
  <Header>
    <Application name="JPMML-SkLearn" version="1.5.34"/>
    <Timestamp>2020-03-24T06:07:44Z</Timestamp>
  </Header>
  <MiningBuildTask>
    <Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best'))])</Extension>
  </MiningBuildTask>
  <DataDictionary>
    <DataField name="y" optype="categorical" dataType="integer">
      <Value value="0"/>
      <Value value="1"/>
      <Value value="2"/>
    </DataField>
    <DataField name="x2" optype="continuous" dataType="float"/>
    <DataField name="x3" optype="continuous" dataType="float"/>
    <DataField name="x4" optype="continuous" dataType="float"/>
  </DataDictionary>
  <TransformationDictionary/>
  <TreeModel functionName="classification" missingValueStrategy="nullPrediction">
    <MiningSchema>
      <MiningField name="y" usageType="target"/>
      <MiningField name="x3"/>
      <MiningField name="x4"/>
      <MiningField name="x2"/>
    </MiningSchema>
    <Output>
      <OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
      <OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
      <OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/>
    </Output>
    <LocalTransformations>
      <DerivedField name="double(x3)" optype="continuous" dataType="double">
        <FieldRef field="x3"/>
      </DerivedField>
      <DerivedField name="double(x4)" optype="continuous" dataType="double">
        <FieldRef field="x4"/>
      </DerivedField>
      <DerivedField name="double(x2)" optype="continuous" dataType="double">
        <FieldRef field="x2"/>
      </DerivedField>
    </LocalTransformations>
    <Node>
      <True/>
      <Node score="0" recordCount="50.0">
        <SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.449999988079071"/>
        <ScoreDistribution value="0" recordCount="50.0"/>
        <ScoreDistribution value="1" recordCount="0.0"/>
        <ScoreDistribution value="2" recordCount="0.0"/>
      </Node>
      <Node>
        <SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.75"/>
        <Node>
          <SimplePredicate field="double(x3)" operator="lessOrEqual" value="4.950000047683716"/>
          <Node score="1" recordCount="47.0">
            <SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.6500000357627869"/>
            <ScoreDistribution value="0" recordCount="0.0"/>
            <ScoreDistribution value="1" recordCount="47.0"/>
            <ScoreDistribution value="2" recordCount="0.0"/>
          </Node>
          <Node score="2" recordCount="1.0">
            <True/>
            <ScoreDistribution value="0" recordCount="0.0"/>
            <ScoreDistribution value="1" recordCount="0.0"/>
            <ScoreDistribution value="2" recordCount="1.0"/>
          </Node>
        </Node>
        <Node score="2" recordCount="3.0">
          <SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.550000011920929"/>
          <ScoreDistribution value="0" recordCount="0.0"/>
          <ScoreDistribution value="1" recordCount="0.0"/>
          <ScoreDistribution value="2" recordCount="3.0"/>
        </Node>
        <Node score="1" recordCount="2.0">
          <SimplePredicate field="double(x3)" operator="lessOrEqual" value="5.450000047683716"/>
          <ScoreDistribution value="0" recordCount="0.0"/>
          <ScoreDistribution value="1" recordCount="2.0"/>
          <ScoreDistribution value="2" recordCount="0.0"/>
        </Node>
        <Node score="2" recordCount="1.0">
          <True/>
          <ScoreDistribution value="0" recordCount="0.0"/>
          <ScoreDistribution value="1" recordCount="0.0"/>
          <ScoreDistribution value="2" recordCount="1.0"/>
        </Node>
      </Node>
      <Node>
        <SimplePredicate field="double(x3)" operator="lessOrEqual" value="4.8500001430511475"/>
        <Node score="2" recordCount="2.0">
          <SimplePredicate field="double(x2)" operator="lessOrEqual" value="3.100000023841858"/>
          <ScoreDistribution value="0" recordCount="0.0"/>
          <ScoreDistribution value="1" recordCount="0.0"/>
          <ScoreDistribution value="2" recordCount="2.0"/>
        </Node>
        <Node score="1" recordCount="1.0">
          <True/>
          <ScoreDistribution value="0" recordCount="0.0"/>
          <ScoreDistribution value="1" recordCount="1.0"/>
          <ScoreDistribution value="2" recordCount="0.0"/>
        </Node>
      </Node>
      <Node score="2" recordCount="43.0">
        <True/>
        <ScoreDistribution value="0" recordCount="0.0"/>
        <ScoreDistribution value="1" recordCount="0.0"/>
        <ScoreDistribution value="2" recordCount="43.0"/>
      </Node>
    </Node>
  </TreeModel>
</PMML>
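Read top to bottom, the nested `<Node>` elements form an ordinary if/else chain: starting at the root, the evaluator descends into the first child whose predicate is true, and a leaf's `score` is the predicted class. As an illustration, here is a hand translation of the tree above into plain Java (written by hand, not generated by any tool; `IrisTreeRules` and `classify` are hypothetical names). Inputs are widened from float to double, mirroring the `double(x2)`, `double(x3)` and `double(x4)` derived fields; x1 is accepted but unused because no split tests it.

```java
/** Hand translation of the Iris TreeModel above (illustrative only; the class name is hypothetical). */
public class IrisTreeRules {

    /** Returns the predicted class 0, 1 or 2. x1 is accepted but unused: no split in the tree tests it. */
    public static int classify(float x1, float x2, float x3, float x4) {
        // Mirror the LocalTransformations: each float feature is widened to double before comparing.
        double d2 = x2, d3 = x3, d4 = x4;

        if (d3 <= 2.449999988079071) {
            return 0;                                   // leaf with 50 training records, all class 0
        }
        if (d4 <= 1.75) {
            if (d3 <= 4.950000047683716) {
                return d4 <= 1.6500000357627869 ? 1 : 2;
            }
            if (d4 <= 1.550000011920929) {
                return 2;
            }
            return d3 <= 5.450000047683716 ? 1 : 2;
        }
        if (d3 <= 4.8500001430511475) {
            return d2 <= 3.100000023841858 ? 2 : 1;
        }
        return 2;                                       // leaf with 43 training records, all class 2
    }

    public static void main(String[] args) {
        // Same inputs as PMMLDemo.main in the Java section
        System.out.println(classify(5.1f, 3.5f, 1.4f, 0.2f)); // prints 0
        System.out.println(classify(6.9f, 3.1f, 5.1f, 2.3f)); // prints 2
    }
}
```

Hard-coding rules like this is only useful for understanding the file; the point of PMML is that a generic evaluator such as JPMML can execute any exported tree without hand-written code.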
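Generating a PMML file with sklearn
- Install sklearn2pmml (it runs the Java-based JPMML-SkLearn converter under the hood, so a Java runtime must also be installed)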
pip install --user --upgrade git+https://github.com/jpmml/sklearn2pmml.git
- Train a decision tree on the Iris dataset with sklearn's built-in DecisionTreeClassifier, then export it as a PMML file
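from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

iris = load_iris()
clf = tree.DecisionTreeClassifier()
# Wrap the classifier in a PMMLPipeline so that sklearn2pmml can export it
pipeline = PMMLPipeline([("classifier", clf)])
# iris.data is a plain NumPy array, so the exported fields are named x1..x4 (features) and y (target)
pipeline.fit(iris.data, iris.target)
# Export to PMML; with_repr=True stores the pipeline's repr in the file (the <Extension> element above)
# The Java loader must later read this same file path
sklearn2pmml(pipeline, "/Desktop/DecisionTreeIris.pmml", with_repr=True)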
- Create a new Java Maven project and add the following dependencies:
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.1</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.1</version>
</dependency>
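Because `pipeline.fit` received `iris.data` as a plain NumPy array with no column names, sklearn2pmml named the features x1 to x4 and the target y; the DataDictionary above lists y, x2, x3 and x4 (x1 is presumably absent because no split uses it). The Java side therefore has to supply values under those names, following the training column order. A minimal sketch of that mapping, using a hypothetical helper class `IrisFeatureMapper`:

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper: maps one Iris row (training column order) to the names sklearn2pmml generated. */
public class IrisFeatureMapper {

    /** Column i of the training matrix becomes feature "x" + (i + 1), matching the exported PMML. */
    public static Map<String, Float> toFeatureMap(float[] row) {
        Map<String, Float> features = new LinkedHashMap<>();
        for (int i = 0; i < row.length; i++) {
            features.put("x" + (i + 1), row[i]);
        }
        return features;
    }

    public static void main(String[] args) {
        // sepal length, sepal width, petal length, petal width: the column order of iris.data
        float[] row = {5.1f, 3.5f, 1.4f, 0.2f};
        System.out.println(toFeatureMap(row)); // prints {x1=5.1, x2=3.5, x3=1.4, x4=0.2}
    }
}
```

The `LinkedHashMap` only keeps the map readable in column order; JPMML looks input fields up by name, so once the names are right the iteration order of the map does not matter.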
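Two input forms for the Java interface
- Parse the decision tree model in Java from a .pmml file path and output predictions
- Alternatively, pass the PMML content in as a string and wrap it in a byte stream, so the interface can take a string argument instead of a file path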
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 * Created by sanyin
 * on 2020/03/24.
 */
public class PMMLDemo {
    // Form 1: load the model from the path of a .pmml file
    private Evaluator loadPmml(String path) throws IOException, SAXException, JAXBException {
        // try-with-resources closes the stream even if parsing fails
        try (InputStream is = new FileInputStream(path)) {
            return evaluatorFrom(is);
        }
    }
    // Form 2: load the model from the PMML content passed as a string, so the interface
    // does not need a file path; the string is wrapped in a byte stream
    private Evaluator loadPmmlFromString(String pmmlXml) throws IOException, SAXException, JAXBException {
        try (InputStream is = new ByteArrayInputStream(pmmlXml.getBytes(StandardCharsets.UTF_8))) {
            return evaluatorFrom(is);
        }
    }
    // Shared step: unmarshal the XML into a PMML object and build an evaluator for it
    private Evaluator evaluatorFrom(InputStream is) throws SAXException, JAXBException {
        PMML pmml = PMMLUtil.unmarshal(is);
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }
    private int predict(Evaluator evaluator, float a, float b, float c, float d) {
        // Raw feature values: Iris has 4 features, and sklearn2pmml named the training
        // columns x1..x4 in order, so a..d must follow the same column order as in Python
        Map<String, Float> data = new HashMap<>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);
        // Build the model input: each field the model uses is looked up by name
        // and prepared (converted and validated) by JPMML; unused names such as x1 are ignored
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }
        // Evaluate and read the single target field (y)
        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();
        Object targetFieldValue = results.get(targetFieldName);
        System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
        // The target value is wrapped in a Computable; unwrap it to the integer class label
        int primitiveValue = -1;
        if (targetFieldValue instanceof Computable) {
            Computable computable = (Computable) targetFieldValue;
            primitiveValue = (Integer) computable.getResult();
        }
        return primitiveValue;
    }
    public static void main(String[] args) throws Exception {
        PMMLDemo demo = new PMMLDemo();
        // Must point at the file written by sklearn2pmml(...) in the Python step;
        // alternatively: demo.loadPmmlFromString(pmmlXml)
        Evaluator model = demo.loadPmml("/Users/hzp/Desktop/DecisionTreeIris.pmml");
        System.out.println(demo.predict(model, 5.1f, 3.5f, 1.4f, 0.2f));
        System.out.println(demo.predict(model, 6.9f, 3.1f, 5.1f, 2.3f));
    }
}
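`evaluator.evaluate(arguments)` hides how the tree is actually scored. As a rough illustration of the rule an evaluator follows for a TreeModel (start at the root, descend into the first child `<Node>` whose predicate is true, stop at a leaf and return its `score`), here is a simplified sketch that uses only the JDK's DOM parser. It is not how JPMML is implemented; `MiniTreeWalker` is a hypothetical name, and it assumes only `<True/>` and `SimplePredicate` with `lessOrEqual`/`greaterThan`, with no missing values and no LocalTransformations.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.xml.sax.SAXException;

/**
 * Simplified TreeModel scorer (illustration only, not JPMML). Supports <True/> and
 * SimplePredicate with lessOrEqual / greaterThan; no missing values, no LocalTransformations.
 */
public class MiniTreeWalker {

    /** Scores one record (field name -> value); returns the leaf's score attribute, or null. */
    public static String score(String pmmlXml, Map<String, Double> record) {
        Document doc;
        try {
            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
            factory.setNamespaceAware(true); // PMML elements live in the DMG namespace
            doc = factory.newDocumentBuilder()
                    .parse(new ByteArrayInputStream(pmmlXml.getBytes(StandardCharsets.UTF_8)));
        } catch (ParserConfigurationException | SAXException | IOException e) {
            throw new IllegalArgumentException("not a parsable PMML document", e);
        }
        Element treeModel = (Element) doc.getElementsByTagNameNS("*", "TreeModel").item(0);
        if (treeModel == null) {
            throw new IllegalArgumentException("no TreeModel element");
        }
        // The root Node carries <True/>, so scoring starts directly at it
        return walk(firstChild(treeModel, "Node"), record);
    }

    private static String walk(Element node, Map<String, Double> record) {
        boolean hasChildNode = false;
        for (Node n = node.getFirstChild(); n != null; n = n.getNextSibling()) {
            if (n instanceof Element && "Node".equals(n.getLocalName())) {
                hasChildNode = true;
                if (matches((Element) n, record)) {
                    return walk((Element) n, record); // the first true child wins
                }
            }
        }
        // A leaf returns its score; an inner node with no true child returns null,
        // like PMML's default noTrueChildStrategy="returnNullPrediction"
        return hasChildNode ? null : node.getAttribute("score");
    }

    private static boolean matches(Element node, Map<String, Double> record) {
        if (firstChild(node, "True") != null) {
            return true;
        }
        Element p = firstChild(node, "SimplePredicate");
        if (p == null) {
            throw new IllegalArgumentException("predicate type not supported by this sketch");
        }
        double x = record.get(p.getAttribute("field")); // assumes the field is present
        double v = Double.parseDouble(p.getAttribute("value"));
        switch (p.getAttribute("operator")) {
            case "lessOrEqual": return x <= v;
            case "greaterThan": return x > v;
            default: throw new IllegalArgumentException("unsupported operator: " + p.getAttribute("operator"));
        }
    }

    private static Element firstChild(Element parent, String localName) {
        for (Node n = parent.getFirstChild(); n != null; n = n.getNextSibling()) {
            if (n instanceof Element && localName.equals(n.getLocalName())) {
                return (Element) n;
            }
        }
        return null;
    }
}
```

Against the Iris file above, the record keys would be the derived-field names such as `double(x3)`, since those are the names the `SimplePredicate` elements reference; a real evaluator computes them from x2..x4 through the LocalTransformations.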
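Notes
- Mind the data types used in Python: the input types in Java must match them (here the DataDictionary declares the features as float and the target y as integer)
- Mind the number and order of features used to train the model in Python: the features supplied from Java must match them
- Other models work the same way, as long as sklearn2pmml can export them to PMML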
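To check what a particular .pmml file expects before wiring up the Java call, one option is to list its DataDictionary with the JDK's built-in DOM parser; no JPMML dependency is needed. The sketch below (the class name `PmmlFieldLister` is hypothetical) reads the PMML from a string through a byte stream, the same idea as the string-input form of the Java interface, and returns each field's name and dataType. Note that the DataDictionary also includes the target y; the fields a model actually consumes are the active fields of its MiningSchema, which is what `evaluator.getInputFields()` returns in PMMLDemo.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

/** Lists the DataDictionary fields of a PMML document as name -> dataType, using only the JDK. */
public class PmmlFieldLister {

    public static Map<String, String> dataFields(String pmmlXml) {
        try {
            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
            factory.setNamespaceAware(true); // PMML elements live in the DMG namespace
            Document doc = factory.newDocumentBuilder()
                    .parse(new ByteArrayInputStream(pmmlXml.getBytes(StandardCharsets.UTF_8)));
            Map<String, String> fields = new LinkedHashMap<>(); // keeps document order
            NodeList dictionaries = doc.getElementsByTagNameNS("*", "DataDictionary");
            if (dictionaries.getLength() == 0) {
                return fields;
            }
            // Only look for DataField elements inside the (first) DataDictionary
            NodeList dataFields = ((Element) dictionaries.item(0)).getElementsByTagNameNS("*", "DataField");
            for (int i = 0; i < dataFields.getLength(); i++) {
                Element field = (Element) dataFields.item(i);
                fields.put(field.getAttribute("name"), field.getAttribute("dataType"));
            }
            return fields;
        } catch (ParserConfigurationException | SAXException | IOException e) {
            throw new IllegalArgumentException("not a parsable PMML document", e);
        }
    }

    public static void main(String[] args) {
        String pmml = "<PMML xmlns=\"http://www.dmg.org/PMML-4_3\" version=\"4.3\"><DataDictionary>"
                + "<DataField name=\"y\" optype=\"categorical\" dataType=\"integer\"/>"
                + "<DataField name=\"x2\" optype=\"continuous\" dataType=\"float\"/>"
                + "</DataDictionary></PMML>";
        System.out.println(dataFields(pmml)); // prints {y=integer, x2=float}
    }
}
```

Run against the Iris file, this would report y as integer and x2, x3, x4 as float, which is why the Java demo passes float values.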
References