用PMML实现机器学习模型的跨平台上线

参考:用PMML实现机器学习模型的跨平台上线

  • 目标:在Spark环境下训练机器学习模型,并在Java环境下加载该模型进行预测(推理)

  • 过程1:在Spark环境下训练机器学习模型,并将训练好的模型导出为PMML文件

Code:

import java.io.File

import javax.xml.transform.stream.StreamResult
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.SparkSession
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder


/**
 * Step 1 of the PMML demo: trains a decision-tree classifier on `appused2gender.csv`
 * with a Spark ML pipeline (RFormula + DecisionTreeClassifier) and exports the fitted
 * pipeline as a PMML document.
 *
 * The PMML is printed to stdout and also written to [[PmmlOutputFile]] in the
 * current working directory.
 */
object DecisionTreeClassifierDemo {
  // Keep the console readable: only ERROR-level messages from Spark itself.
  Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
  // Silence Jetty's server logging. Jetty lives under `org.eclipse.jetty`
  // (there is no `org.apache.eclipse` package); Spark 2.x additionally
  // relocates its bundled Jetty to `org.spark_project.jetty`.
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
  Logger.getLogger("org.spark_project.jetty").setLevel(Level.OFF)

  /** File the exported PMML model is written to (relative to the working directory). */
  private val PmmlOutputFile = "Appused2genderDecisionTreeClassifier.pmml"

  /**
   * Trains the model and exports it as PMML.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    // Cluster submission example:
    // spark2-submit --class DecisionTreeClassifierDemo ./DecisionTreeClassifierDemo.jar &> ./log/DecisionTreeClassifierDemo.log

    // Local (offline) mode with 4 worker threads.
    val spark = SparkSession
      .builder
      .appName("DecisionTreeClassifier")
      .master("local[4]")
      .getOrCreate()

    try {
      // Load the training set: the header row supplies column names and
      // inferSchema parses numeric columns as numbers instead of strings.
      val appData = spark.read
        .format("csv")
        .options(Map(
          "header" -> "true",
          "inferSchema" -> "true"))
        .load("appused2gender.csv")

      appData.show(5)
      appData.printSchema()
      /*
      +------------------+----------+-------------------+------------------+------------------+------------------+------------------+------+
      |          duration|use_number|     duration_trend|  duration_workday|use_number_workday|  duration_weekend|use_number_weekend|gender|
      +------------------+----------+-------------------+------------------+------------------+------------------+------------------+------+
      |465.46666666666675| 156.65625|-0.9740957966764417| 443.9840277777778|145.04166666666666| 529.9145833333333|             191.5|     1|
      | 711.4791666666666| 727.96875| 21.685239491691103| 648.4458333333334|           792.125| 900.5791666666668|             535.5|     0|
      |148.50833333333333|     212.6| 17.145959595959596|122.58809523809524|             207.0| 208.9888888888889|225.66666666666666|     0|
      | 582.4250000000001|     227.5|  41.51378260869564|357.47631578947363|127.36842105263158|           1437.23|             608.0|     0|
      |383.83833333333337|     263.5| -4.783333333333332|294.41666666666663|             202.8|473.25999999999993|             324.2|     0|
      +------------------+----------+-------------------+------------------+------------------+------------------+------------------+------+
      */

      // Feature selection: "gender ~ ." predicts `gender` from every other column.
      val formula = new RFormula().setFormula("gender ~ .")

      // Decision tree reads the label/feature columns produced by the formula.
      val classifier = new DecisionTreeClassifier()
        .setLabelCol(formula.getLabelCol)
        .setFeaturesCol(formula.getFeaturesCol)

      // Chain formula and classifier into one pipeline, then train it.
      val pipeline = new Pipeline().setStages(Array[PipelineStage](formula, classifier))
      val pipelineModel = pipeline.fit(appData)

      // The schema handed to PMMLBuilder is that of the raw training DataFrame.
      val schema = appData.schema
      schema.printTreeString()
      println("schema: " + schema)

      // A single builder serves both outputs instead of constructing it twice.
      val pmmlBuilder = new PMMLBuilder(schema, pipelineModel)

      // Stream the PMML document to the console.
      JAXBUtil.marshalPMML(pmmlBuilder.build, new StreamResult(System.out))

      // Write the PMML document to a file.
      pmmlBuilder.buildFile(new File(PmmlOutputFile))
    } finally {
      // Release the local Spark context even when training or export fails.
      spark.stop()
    }
  }
}

过程1对应的 Maven 配置文件(pom.xml,在 IDEA 中使用)如下:

<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for step 1: the Spark job that trains the model and exports it as PMML. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xdb.spark2pmml</groupId>
    <artifactId>Spark2PmmlDemo</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <spark.version>2.3.0</spark.version>
        <!-- Despite its name, this is the Scala *binary* version used as the artifact
             suffix (e.g. spark-core_2.11), not a full Scala version such as 2.11.12. -->
        <scala.version>2.11</scala.version>
    </properties>


    <dependencies>
        <!-- Spark modules. The demo code only imports spark-sql and spark-mllib classes;
             spark-streaming and spark-hive are not used by DecisionTreeClassifierDemo. -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.jpmml/jpmml-sparkml -->
        <!-- NOTE(review): JPMML-SparkML release lines are tied to Spark versions; presumably
             1.4.x matches Spark 2.3.x. Confirm against the JPMML-SparkML README when upgrading either.
             spark-mllib may also pull in an older org.jpmml:pmml-model; if NoSuchMethodError shows
             up at export time, inspect `mvn dependency:tree` (TODO confirm). -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>jpmml-sparkml</artifactId>
            <version>1.4.11</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>

            <!-- Compiles the Scala sources (compile and testCompile goals).
                 NOTE(review): org.scala-tools:maven-scala-plugin is a legacy coordinate;
                 net.alchim31.maven:scala-maven-plugin is its successor. -->
            <plugin>
                <groupId>org.scala-tools</groupId>
                <artifactId>maven-scala-plugin</artifactId>
                <version>2.15.2</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>

            <!-- Java source/target level for any Java sources in the module. -->
            <plugin>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.6.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>

            <!-- Tests are skipped during the build. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <version>2.19</version>
                <configuration>
                    <skip>true</skip>
                </configuration>
            </plugin>

        </plugins>
    </build>
</project>

导出的PMML文件内容如下(节选,中间部分 Node 已省略):

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
	<Header>
		<Application name="JPMML-SparkML" version="1.4.11"/>
		<Timestamp>2020-04-19T11:25:07Z</Timestamp>
	</Header>
	<DataDictionary>
		<DataField name="gender" optype="categorical" dataType="integer">
			<Value value="0"/>
			<Value value="1"/>
		</DataField>
		<DataField name="duration" optype="continuous" dataType="double"/>
		<DataField name="use_number" optype="continuous" dataType="double"/>
		<DataField name="duration_trend" optype="continuous" dataType="double"/>
		<DataField name="duration_workday" optype="continuous" dataType="double"/>
		<DataField name="use_number_workday" optype="continuous" dataType="double"/>
		<DataField name="duration_weekend" optype="continuous" dataType="double"/>
		<DataField name="use_number_weekend" optype="continuous" dataType="double"/>
	</DataDictionary>
	<TreeModel functionName="classification" missingValueStrategy="nullPrediction">
		<MiningSchema>
			<MiningField name="gender" usageType="target"/>
			<MiningField name="duration"/>
			<MiningField name="use_number"/>
			<MiningField name="duration_trend"/>
			<MiningField name="duration_workday"/>
			<MiningField name="use_number_workday"/>
			<MiningField name="duration_weekend"/>
			<MiningField name="use_number_weekend"/>
		</MiningSchema>
		<Output>
			<OutputField name="pmml(prediction)" optype="categorical" dataType="integer" isFinalResult="false"/>
			<OutputField name="prediction" optype="continuous" dataType="double" feature="transformedValue">
				<MapValues outputColumn="data:output" dataType="double">
					<FieldColumnPair field="pmml(prediction)" column="data:input"/>
					<InlineTable>
						<row>
							<data:input>0</data:input>
							<data:output>0</data:output>
						</row>
						<row>
							<data:input>1</data:input>
							<data:output>1</data:output>
						</row>
					</InlineTable>
				</MapValues>
			</OutputField>
			<OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
			<OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
		</Output>
		<Node>
			<True/>
			<Node>
				<SimplePredicate field="duration" operator="lessOrEqual" value="770.2519269776877"/>
				<Node>
					<SimplePredicate field="duration_weekend" operator="lessOrEqual" value="133.434375"/>
					<Node>
						<SimplePredicate field="duration_workday" operator="lessOrEqual" value="571.2881944444446"/>
						<Node>
							<SimplePredicate field="use_number_weekend" operator="lessOrEqual" value="467.5"/>
							<Node score="0" recordCount="400">
								<SimplePredicate field="use_number_workday" operator="lessOrEqual" value="486.6594202898551"/>
								<ScoreDistribution value="0" recordCount="227.0"/>
								<ScoreDistribution value="1" recordCount="173.0"/>
							</Node>
							<Node score="1" recordCount="8">
								<True/>
								<ScoreDistribution value="0" recordCount="2.0"/>
								<ScoreDistribution value="1" recordCount="6.0"/>
							</Node>
						</Node>
						<Node score="1" recordCount="4">
							<True/>
							<ScoreDistribution value="0" recordCount="0.0"/>
							<ScoreDistribution value="1" recordCount="4.0"/>
						</Node>
					</Node>
					<Node>
						<SimplePredicate field="duration_trend" operator="lessOrEqual" value="-73.75824404761904"/>
						<Node score="1" recordCount="1">
							<SimplePredicate field="duration" operator="lessOrEqual" value="539.9891025641025"/>
							<ScoreDistribution value="0" recordCount="0.0"/>
							<ScoreDistribution value="1" recordCount="1.0"/>
						</Node>
						<Node score="0" recordCount="4">
							<True/>
							<ScoreDistribution value="0" recordCount="4.0"/>
							<ScoreDistribution value="1" recordCount="0.0"/>
						</Node>
					</Node>
					<Node score="0" recordCount="17">
						<True/>
						<ScoreDistribution value="0" recordCount="17.0"/>
						<ScoreDistribution value="1" recordCount="0.0"/>
					</Node>
				</Node>
				<Node>
					<SimplePredicate field="duration_trend" operator="lessOrEqual" value="20.406185850439883"/>
					<Node>
						<SimplePredicate field="duration_trend" operator="lessOrEqual" value="-16.603133718344537"/>
						<Node score="0" recordCount="162">
							<SimplePredicate field="use_number" operator="lessOrEqual" value="326.46770833333335"/>
							<ScoreDistribution value="0" recordCount="98.0"/>
							<ScoreDistribution value="1" recordCount="64.0"/>
						</Node>
						<Node score="1" recordCount="96">
							<True/>
							<ScoreDistribution value="0" recordCount="41.0"/>
							<ScoreDistribution value="1" recordCount="55.0"/>
						</Node>
					</Node>
					<Node score="0" recordCount="134">
						<SimplePredicate field="duration_workday" operator="lessOrEqual" value="149.45138888888889"/>
						<ScoreDistribution value="0" recordCount="73.0"/>
						<ScoreDistribution value="1" recordCount="61.0"/>
					</Node>
					<Node score="1" recordCount="2247">
						<True/>
						<ScoreDistribution value="0" recordCount="930.0"/>
						<ScoreDistribution value="1" recordCount="1317.0"/>
					</Node>
				</Node>
				<Node>
					<SimplePredicate field="duration_weekend" operator="lessOrEqual" value="925.04375"/>
					<Node score="0" recordCount="26">
						<SimplePredicate field="duration" operator="lessOrEqual" value="302.228125"/>
						<ScoreDistribution value="0" recordCount="19.0"/>
						<ScoreDistribution value="1" recordCount="7.0"/>
					</Node>
					<Node score="0" recordCount="155">
						<True/>
						<ScoreDistribution value="0" recordCount="82.0"/>
						<ScoreDistribution value="1" recordCount="73.0"/>
					</Node>
				</Node>
			<!-- 此处省略若干 Node 节点;下面几行是被省略部分中最后一个 Node 的结尾 -->
				<True/>
				<ScoreDistribution value="0" recordCount="12.0"/>
				<ScoreDistribution value="1" recordCount="24.0"/>
			</Node>
		</Node>
	</TreeModel>
</PMML>

过程2:在Java环境下加载PMML模型文件并进行预测

Code:

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;

/**
 * Step 2 of the PMML demo: loads a PMML model file with JPMML-Evaluator and
 * scores a single hand-written record.
 *
 * Usage: {@code java PMML2EvaluatorDemo [path/to/model.pmml]}
 */
public class PMML2EvaluatorDemo {
    /**
     * Model file used when no path is passed on the command line. The Spark job in
     * step 1 writes its decision tree to a file named Appused2genderDecisionTreeClassifier.pmml.
     */
    private static final String DEFAULT_MODEL_PATH =
            "C:\\Users\\11103378\\IdeaProjects\\Spark\\SparkModel2PMML\\Appused2genderDecisionTreeClassifier.pmml";

    /**
     * Parses a PMML file and builds an evaluator for the model it contains.
     *
     * @param filePath path of the PMML file
     * @return the evaluator, or {@code null} if the file cannot be opened or parsed
     *         (the cause is printed to stderr)
     */
    private Evaluator loadPmml(String filePath) {
        InputStream is;
        try {
            is = new FileInputStream(filePath);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }

        PMML pmml;
        try {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        } catch (SAXException e) {
            // Do not fall through with an empty PMML object: building an evaluator
            // from it would only fail later with a far less helpful error.
            e.printStackTrace();
            return null;
        } catch (JAXBException e) {
            e.printStackTrace();
            return null;
        } finally {
            // Close the input stream on every path.
            try {
                is.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    /**
     * Prints the names of the model's input fields and target fields.
     *
     * @param evaluator evaluator returned by {@link #loadPmml(String)}
     */
    private void showEvaluator(Evaluator evaluator) {
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName inputFieldName = inputField.getName();
            System.out.println("inputFieldName:" + inputFieldName);
        }
        for (TargetField targetField : evaluator.getTargetFields()) {
            FieldName targetFieldName = targetField.getName();
            System.out.println("targetFieldName:" + targetFieldName);
        }
    }

    /**
     * Scores one record and returns the predicted gender.
     *
     * @param evaluator model evaluator
     * @param a duration
     * @param b use_number
     * @param c duration_trend
     * @param d duration_workday
     * @param e use_number_workday
     * @param f duration_weekend
     * @param g use_number_weekend
     * @return the first target field's value as an int, or -1 if it is not numeric
     */
    private int predictAppused2gender(Evaluator evaluator, double a, double b, double c, double d, double e, double f, double g) {
        Map<String, Double> data = new HashMap<String, Double>();
        data.put("duration", a);
        data.put("use_number", b);
        data.put("duration_trend", c);
        data.put("duration_workday", d);
        data.put("use_number_workday", e);
        data.put("duration_weekend", f);
        data.put("use_number_weekend", g);

        // Walk the model's raw input fields, look each value up in `data` by field
        // name, and let prepare() convert it into the FieldValue the model expects.
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            arguments.put(inputFieldName, inputField.prepare(rawValue));
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        TargetField targetField = evaluator.getTargetFields().get(0);
        FieldName targetFieldName = targetField.getName();

        Object targetFieldValue = results.get(targetFieldName);
        System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);

        // Classification results arrive wrapped in a Computable (e.g. a score
        // distribution); unwrap it, then accept any numeric type rather than
        // blindly casting to Integer.
        Object decoded = (targetFieldValue instanceof Computable)
                ? ((Computable) targetFieldValue).getResult()
                : targetFieldValue;
        int primitiveValue = -1;
        if (decoded instanceof Number) {
            primitiveValue = ((Number) decoded).intValue();
        }

        shuchu(data);
        System.out.println("duration:" + a + " use_number:" + b + " duration_trend:" + c + " duration_workday:" + d +
                " use_number_workday:" + e + " duration_weekend:" + f + " use_number_weekend:" + g + " --> gender:" + primitiveValue);
        return primitiveValue;
    }

    /**
     * Prints every entry of {@code data} as {@code key:value}, one per line
     * (in the map's iteration order).
     *
     * @param data feature name to value
     */
    public void shuchu(Map<String, Double> data) {
        for (Map.Entry<String, Double> entry : data.entrySet()) {
            System.out.println(entry.getKey() + ":" + entry.getValue());
        }
    }

    /**
     * Loads the model and scores one sample record.
     *
     * @param args optional: args[0] overrides the model path
     */
    public static void main(String[] args) {
        PMML2EvaluatorDemo demo = new PMML2EvaluatorDemo();
        // Load the model saved as PMML; a command-line path overrides the default.
        String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;
        Evaluator model = demo.loadPmml(modelPath);
        if (model == null) {
            System.err.println("Failed to load PMML model from: " + modelPath);
            return;
        }
        demo.showEvaluator(model);
        // Score a single record.
        demo.predictAppused2gender(model, 465.4, 156.6, -0.9, 443.9, 145.0, 529.9, 191.5);
    }
    /*
    Sample output (abridged):
    duration:465.4
    use_number_workday:145.0
    duration_trend:-0.9
    duration_weekend:529.9
    use_number_weekend:191.5
    use_number:156.6
    duration_workday:443.9
    duration:465.4 use_number:156.6 duration_trend:-0.9 duration_workday:443.9 use_number_workday:145.0 duration_weekend:529.9 use_number_weekend:191.5 --> gender:1
    */
}

过程2对应的 Maven 配置文件 pom.xml 如下:

<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for step 2: the plain-Java program that loads the PMML file and scores records. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xdb.spark</groupId>
    <artifactId>LoadPmmlDemo</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <java.version>1.8</java.version>
        <!-- maven-compiler-plugin does not read java.version by itself; without these two
             properties the plugin's own default language level is used instead of 1.8. -->
        <maven.compiler.source>${java.version}</maven.compiler.source>
        <maven.compiler.target>${java.version}</maven.compiler.target>
        <pmml-model.version>1.4.2</pmml-model.version>
        <pmml-evaluator.version>1.4.2</pmml-evaluator.version>
    </properties>

    <dependencies>
        <!-- PMML object model (org.dmg.pmml.*, org.jpmml.model.*) -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-model</artifactId>
            <version>${pmml-model.version}</version>
        </dependency>
        <!-- Evaluator runtime (org.jpmml.evaluator.*) and its extension module -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>${pmml-evaluator.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>${pmml-evaluator.version}</version>
        </dependency>
    </dependencies>
</project>

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值