Parsing PMML Machine Learning Models in Java

Background

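In real-world engineering projects we usually train and iterate on models in Python, because it offers powerful algorithm libraries and very convenient data-processing tools that make rapid experimentation possible. When a model has to be deployed as a service, however, Java and its frameworks have clear advantages. To combine the strengths of both languages, PMML can serve as the bridge: export the model to a .pmml file in Python, then parse and deploy it in Java.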

What is PMML?

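PMML (Predictive Model Markup Language) is an XML-based file format for exporting machine learning models. The file content describes the model's rules; for example, here is a decision tree trained on the Iris dataset: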

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
	<Header>
		<Application name="JPMML-SkLearn" version="1.5.34"/>
		<Timestamp>2020-03-24T06:07:44Z</Timestamp>
	</Header>
	<MiningBuildTask>
		<Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best'))])</Extension>
	</MiningBuildTask>
	<DataDictionary>
		<DataField name="y" optype="categorical" dataType="integer">
			<Value value="0"/>
			<Value value="1"/>
			<Value value="2"/>
		</DataField>
		<DataField name="x2" optype="continuous" dataType="float"/>
		<DataField name="x3" optype="continuous" dataType="float"/>
		<DataField name="x4" optype="continuous" dataType="float"/>
	</DataDictionary>
	<TransformationDictionary/>
	<TreeModel functionName="classification" missingValueStrategy="nullPrediction">
		<MiningSchema>
			<MiningField name="y" usageType="target"/>
			<MiningField name="x3"/>
			<MiningField name="x4"/>
			<MiningField name="x2"/>
		</MiningSchema>
		<Output>
			<OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
			<OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
			<OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/>
		</Output>
		<LocalTransformations>
			<DerivedField name="double(x3)" optype="continuous" dataType="double">
				<FieldRef field="x3"/>
			</DerivedField>
			<DerivedField name="double(x4)" optype="continuous" dataType="double">
				<FieldRef field="x4"/>
			</DerivedField>
			<DerivedField name="double(x2)" optype="continuous" dataType="double">
				<FieldRef field="x2"/>
			</DerivedField>
		</LocalTransformations>
		<Node>
			<True/>
			<Node score="0" recordCount="50.0">
				<SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.449999988079071"/>
				<ScoreDistribution value="0" recordCount="50.0"/>
				<ScoreDistribution value="1" recordCount="0.0"/>
				<ScoreDistribution value="2" recordCount="0.0"/>
			</Node>
			<Node>
				<SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.75"/>
				<Node>
					<SimplePredicate field="double(x3)" operator="lessOrEqual" value="4.950000047683716"/>
					<Node score="1" recordCount="47.0">
						<SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.6500000357627869"/>
						<ScoreDistribution value="0" recordCount="0.0"/>
						<ScoreDistribution value="1" recordCount="47.0"/>
						<ScoreDistribution value="2" recordCount="0.0"/>
					</Node>
					<Node score="2" recordCount="1.0">
						<True/>
						<ScoreDistribution value="0" recordCount="0.0"/>
						<ScoreDistribution value="1" recordCount="0.0"/>
						<ScoreDistribution value="2" recordCount="1.0"/>
					</Node>
				</Node>
				<Node score="2" recordCount="3.0">
					<SimplePredicate field="double(x4)" operator="lessOrEqual" value="1.550000011920929"/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="0.0"/>
					<ScoreDistribution value="2" recordCount="3.0"/>
				</Node>
				<Node score="1" recordCount="2.0">
					<SimplePredicate field="double(x3)" operator="lessOrEqual" value="5.450000047683716"/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="2.0"/>
					<ScoreDistribution value="2" recordCount="0.0"/>
				</Node>
				<Node score="2" recordCount="1.0">
					<True/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="0.0"/>
					<ScoreDistribution value="2" recordCount="1.0"/>
				</Node>
			</Node>
			<Node>
				<SimplePredicate field="double(x3)" operator="lessOrEqual" value="4.8500001430511475"/>
				<Node score="2" recordCount="2.0">
					<SimplePredicate field="double(x2)" operator="lessOrEqual" value="3.100000023841858"/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="0.0"/>
					<ScoreDistribution value="2" recordCount="2.0"/>
				</Node>
				<Node score="1" recordCount="1.0">
					<True/>
					<ScoreDistribution value="0" recordCount="0.0"/>
					<ScoreDistribution value="1" recordCount="1.0"/>
					<ScoreDistribution value="2" recordCount="0.0"/>
				</Node>
			</Node>
			<Node score="2" recordCount="43.0">
				<True/>
				<ScoreDistribution value="0" recordCount="0.0"/>
				<ScoreDistribution value="1" recordCount="0.0"/>
				<ScoreDistribution value="2" recordCount="43.0"/>
			</Node>
		</Node>
	</TreeModel>
</PMML>
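To make the structure above concrete, here is a hand-written Java equivalent of this particular tree. It is only an illustration derived by reading the XML (the class name IrisTreeSketch is hypothetical); JPMML evaluates the PMML file itself, not code like this. Child nodes are tried in document order and the first one whose predicate is true is taken, which is why several sibling nodes end with <True/> as an "else" branch. Note that x1 (sepal length) never appears: the tree does not split on it, so it was left out of the DataDictionary. Missing values (missingValueStrategy="nullPrediction") are not modelled here.

```java
/**
 * Illustrative, hand-written equivalent of the Iris TreeModel above.
 * Not generated by JPMML; thresholds are copied from the SimplePredicate values.
 */
public class IrisTreeSketch {

    /** Returns the predicted class, i.e. the score of the leaf that is reached. */
    static int predict(float x2, float x3, float x4) {
        // The DerivedFields double(x2), double(x3), double(x4) cast the float inputs to double
        double dx2 = x2, dx3 = x3, dx4 = x4;

        if (dx3 <= 2.449999988079071) {
            return 0;                                        // 50 training records
        }
        if (dx4 <= 1.75) {
            if (dx3 <= 4.950000047683716) {
                return dx4 <= 1.6500000357627869 ? 1 : 2;    // 47 / 1 records
            }
            if (dx4 <= 1.550000011920929) {
                return 2;                                    // 3 records
            }
            if (dx3 <= 5.450000047683716) {
                return 1;                                    // 2 records
            }
            return 2;                                        // 1 record
        }
        if (dx3 <= 4.8500001430511475) {
            return dx2 <= 3.100000023841858 ? 2 : 1;         // 2 / 1 records
        }
        return 2;                                            // 43 records
    }

    public static void main(String[] args) {
        // Same samples as the Java demo below, minus x1, which the tree does not use
        System.out.println(predict(3.5f, 1.4f, 0.2f)); // 0
        System.out.println(predict(3.1f, 5.1f, 2.3f)); // 2
    }
}
```

The leaf record counts in the comments add up to the 150 Iris samples, which is a quick way to check that every branch of the XML has been accounted for.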

Generating a PMML file with sklearn

  • Install sklearn2pmml (it runs a Java backend during export, so a Java runtime must be on the PATH):
pip install --user --upgrade git+https://github.com/jpmml/sklearn2pmml.git
  • Train a decision tree on scikit-learn's built-in Iris dataset and export it as a PMML file:
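from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
clf = tree.DecisionTreeClassifier()
# iris.data is a plain array without column names, so sklearn2pmml names the
# features x1..x4 and the target y; the Java code below relies on these names
pipeline = PMMLPipeline([("classifier", clf)])
pipeline.fit(iris.data, iris.target)

# Export to PMML (with_repr=True stores the pipeline's repr in the <Extension> element)
sklearn2pmml(pipeline, "/Desktop/DecisionTreeIris.pmml", with_repr = True)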
  • Create a Java Maven project and add the following dependencies:
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.1</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.1</version>
</dependency>

Two input forms for the Java interface

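  • Parse the decision tree model in Java and output predictions.
  • In addition, the PMML content can be passed in as a string and converted into a byte stream, so an interface can accept the model as a string parameter instead of a file path (see the commented-out line in loadPmml).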
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Created by sanyin
 *  on 2020/03/24.
 */
public class PMMLDemo {
    private Evaluator loadPmml() {
        InputStream inputStream;
        // Alternative: pass the PMML content itself as a String (pmmlText) and wrap it in a
        // byte stream, so the interface does not need a file path:
        // inputStream = new java.io.ByteArrayInputStream(pmmlText.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        try {
            inputStream = new FileInputStream("/Users/hzp/Desktop/DecisionTreeIris.pmml");
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
        PMML pmml;
        // try-with-resources closes the input stream even if parsing fails
        try (InputStream is = inputStream) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        } catch (SAXException | JAXBException | IOException e) {
            e.printStackTrace();
            return null;
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    private int predict(Evaluator evaluator, float a, float b, float c, float d) {
        // Raw feature values. Iris has 4 features, named x1..x4 by sklearn2pmml in the column
        // order of iris.data (sepal length, sepal width, petal length, petal width); keep that order
        Map<String, Float> data = new HashMap<>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);

        // Build the model input; getInputFields() only lists the fields the model uses
        // (here x2, x3, x4), so the unused x1 is simply ignored
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            // prepare() converts and validates the raw value against the field definition
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();

        Object targetFieldValue = results.get(targetFieldName);
        System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
        int primitiveValue = -1;
        // For classification the target value is a Computable; getResult() returns the
        // predicted class, an Integer because y has dataType="integer"
        if (targetFieldValue instanceof Computable) {
            Computable computable = (Computable) targetFieldValue;
            primitiveValue = (Integer) computable.getResult();
        }
        return primitiveValue;
    }

    public static void main(String[] args) {
        PMMLDemo demo = new PMMLDemo();
        Evaluator model = demo.loadPmml();
        if (model == null) {
            System.err.println("Failed to load the PMML model");
            return;
        }
        System.out.println(demo.predict(model, 5.1f, 3.5f, 1.4f, 0.2f));
        System.out.println(demo.predict(model, 6.9f, 3.1f, 5.1f, 2.3f));
    }
}
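The string-input variant from the second bullet comes down to turning text into an InputStream, the same type FileInputStream provides. Below is a minimal, JDK-only sketch; the class and method names are hypothetical. To keep it runnable without JPMML on the classpath, the stream is consumed by the JDK's DOM parser, which just lists the DataField names; in PMMLDemo you would hand the same stream to PMMLUtil.unmarshal instead.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

public class PmmlStringInputSketch {

    /** Wraps PMML text in a byte stream so it can be read like a file. */
    static InputStream fromString(String pmmlXml) {
        return new ByteArrayInputStream(pmmlXml.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Stand-in consumer: collects DataField names with the JDK DOM parser.
     * The real code would pass the stream to PMMLUtil.unmarshal(is) instead.
     */
    static List<String> dataFieldNames(InputStream is) {
        try (InputStream in = is) {
            Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(in);
            NodeList fields = doc.getElementsByTagName("DataField");
            List<String> names = new ArrayList<>();
            for (int i = 0; i < fields.getLength(); i++) {
                names.add(((Element) fields.item(i)).getAttribute("name"));
            }
            return names;
        } catch (ParserConfigurationException | SAXException | IOException e) {
            throw new IllegalArgumentException("Could not read PMML", e);
        }
    }

    public static void main(String[] args) {
        String pmmlXml = "<PMML xmlns=\"http://www.dmg.org/PMML-4_3\" version=\"4.3\">"
                + "<DataDictionary>"
                + "<DataField name=\"y\" optype=\"categorical\" dataType=\"integer\"/>"
                + "<DataField name=\"x3\" optype=\"continuous\" dataType=\"float\"/>"
                + "</DataDictionary></PMML>";
        System.out.println(dataFieldNames(fromString(pmmlXml))); // [y, x3]
    }
}
```

Encoding the string explicitly as UTF-8 matches the encoding="UTF-8" declaration that sklearn2pmml writes into the file header.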

Notes

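  1. Mind the data types used in Python: the Java input types must match them (here the features are float, per the DataDictionary).
  2. Mind the number of features the model was trained on: the Java input must have the same dimensionality and the same column order.
  3. The same approach works for other models, as long as sklearn2pmml can export them to PMML.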
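Notes 1 and 2 can be enforced in code before calling the evaluator. A minimal sketch, assuming the model was trained on the 4-column iris.data array so that sklearn2pmml named the features x1..x4; FeatureMapSketch and toFeatureMap are hypothetical helpers, not part of JPMML. It rejects rows with the wrong number of features and keeps the values as float, matching dataType="float" in the DataDictionary.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper (not part of JPMML) that builds the raw input map for one row. */
public class FeatureMapSketch {

    // Number of columns the model was trained on (assumption: the 4-column iris.data)
    static final int TRAINING_DIM = 4;

    /** Maps row[i] to "x" + (i + 1), the default feature names sklearn2pmml assigns. */
    static Map<String, Float> toFeatureMap(float[] row) {
        if (row.length != TRAINING_DIM) {
            throw new IllegalArgumentException(
                    "expected " + TRAINING_DIM + " features, got " + row.length);
        }
        // LinkedHashMap keeps the training column order, which makes debugging output readable
        Map<String, Float> data = new LinkedHashMap<>();
        for (int i = 0; i < row.length; i++) {
            data.put("x" + (i + 1), row[i]); // stays float, like dataType="float"
        }
        return data;
    }

    public static void main(String[] args) {
        System.out.println(toFeatureMap(new float[] {5.1f, 3.5f, 1.4f, 0.2f}));
        // prints {x1=5.1, x2=3.5, x3=1.4, x4=0.2}
    }
}
```

The resulting map can stand in for the hand-built data map in PMMLDemo.predict; a wrong-sized row then fails fast with a clear message instead of silently producing a prediction from misaligned features.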
