-
目标:在Spark环境下训练机器学习模型,并在Java环境下加载该模型进行预测
-
过程1:在Spark环境下训练机器学习模型,导出训练模型为PMML文件
Code:
import java.io.File
import javax.xml.transform.stream.StreamResult
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.SparkSession
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder
/**
 * Trains a decision-tree classifier on `appused2gender.csv` with Spark ML and exports the
 * fitted pipeline as PMML: once to standard output and once to
 * `Appused2genderDecisionTreeClassifier.pmml` in the working directory.
 *
 * Example submission:
 * {{{
 * spark2-submit --class DecisionTreeClassifierDemo ./DecisionTreeClassifierDemo.jar &> ./log/DecisionTreeClassifierDemo.log
 * }}}
 */
object DecisionTreeClassifierDemo {
  // Keep the terminal readable: only ERROR-level Spark messages are printed (WARN and below are hidden).
  Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
  // Jetty's loggers live under "org.eclipse.jetty" (there is no "org.apache.eclipse" package).
  // NOTE(review): some Spark 2.x builds relocate Jetty into a shaded package — confirm the logger name
  // against the Spark distribution actually in use.
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  /**
   * Program entry point.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Local (offline) mode with four worker threads.
    val spark = SparkSession
      .builder
      .appName("DecisionTreeClassifier")
      .master("local[4]")
      .getOrCreate()

    try {
      // Load the training set; the first row is the header and column types are inferred.
      val appData = spark.read
        .format("csv")
        .options(Map(
          "header" -> "true",
          "inferSchema" -> "true"))
        .load("appused2gender.csv")
      appData.show(5)
      appData.printSchema()
      /*
      +------------------+----------+-------------------+------------------+------------------+------------------+------------------+------+
      | duration|use_number| duration_trend| duration_workday|use_number_workday| duration_weekend|use_number_weekend|gender|
      +------------------+----------+-------------------+------------------+------------------+------------------+------------------+------+
      |465.46666666666675| 156.65625|-0.9740957966764417| 443.9840277777778|145.04166666666666| 529.9145833333333| 191.5| 1|
      | 711.4791666666666| 727.96875| 21.685239491691103| 648.4458333333334| 792.125| 900.5791666666668| 535.5| 0|
      |148.50833333333333| 212.6| 17.145959595959596|122.58809523809524| 207.0| 208.9888888888889|225.66666666666666| 0|
      | 582.4250000000001| 227.5| 41.51378260869564|357.47631578947363|127.36842105263158| 1437.23| 608.0| 0|
      |383.83833333333337| 263.5| -4.783333333333332|294.41666666666663| 202.8|473.25999999999993| 324.2| 0|
      +------------------+----------+-------------------+------------------+------------------+------------------+------------------+------+
      */

      // Feature selection: `gender` is the label, every other column is a feature.
      val formula = new RFormula().setFormula("gender ~ .")

      // Decision-tree classifier wired to the label/feature columns produced by the formula.
      val classifier = new DecisionTreeClassifier()
        .setLabelCol(formula.getLabelCol)
        .setFeaturesCol(formula.getFeaturesCol)

      // Assemble and fit the pipeline.
      val pipeline = new Pipeline().setStages(Array[PipelineStage](formula, classifier))
      val pipelineModel = pipeline.fit(appData)

      val schema = appData.schema
      schema.printTreeString()
      println("schema: " + schema)

      // Convert the fitted pipeline to PMML once, then reuse the same document for both outputs,
      // so the console and the file are guaranteed to show the identical model.
      val pmml = new PMMLBuilder(schema, pipelineModel).build

      // Stream the PMML document to the console.
      JAXBUtil.marshalPMML(pmml, new StreamResult(System.out))

      // Write the PMML document to a file; the stream is closed even if marshalling fails.
      val pmmlOut = new java.io.FileOutputStream(new File("Appused2genderDecisionTreeClassifier.pmml"))
      try {
        JAXBUtil.marshalPMML(pmml, new StreamResult(pmmlOut))
      } finally {
        pmmlOut.close()
      }
    } finally {
      // Release the local Spark context whether or not training/export succeeded.
      spark.stop()
    }
  }
}
IDEA工程的Maven配置文件(pom.xml)如下:
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the Spark training project that exports the model as PMML. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.xdb.spark2pmml</groupId>
<artifactId>Spark2PmmlDemo</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<!-- Version shared by every org.apache.spark dependency below. -->
<spark.version>2.3.0</spark.version>
<!-- Used below only as the artifact-id suffix of the Spark modules (e.g. spark-core_2.11). -->
<!-- NOTE(review): Scala Maven plugins may also read a property named scala.version as the full
     compiler version; confirm the build resolves the intended Scala compiler. -->
<scala.version>2.11</scala.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<!-- JPMML-SparkML: provides the Spark ML pipeline to PMML converter. -->
<!-- NOTE(review): check the JPMML-SparkML documentation for the Spark versions this release
     supports and for any classpath conflicts with libraries bundled by Spark. -->
<!-- https://mvnrepository.com/artifact/org.jpmml/jpmml-sparkml -->
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-sparkml</artifactId>
<version>1.4.11</version>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Compiles the Scala sources (main and test) as part of the Maven build. -->
<plugin>
<groupId>org.scala-tools</groupId>
<artifactId>maven-scala-plugin</artifactId>
<version>2.15.2</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- Java sources are compiled with source/target level 1.8. -->
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.6.0</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
<!-- Unit tests are skipped during the build. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.19</version>
<configuration>
<skip>true</skip>
</configuration>
</plugin>
</plugins>
</build>
</project>
输出PMML文件格式如下:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
<Header>
<Application name="JPMML-SparkML" version="1.4.11"/>
<Timestamp>2020-04-19T11:25:07Z</Timestamp>
</Header>
<DataDictionary>
<DataField name="gender" optype="categorical" dataType="integer">
<Value value="0"/>
<Value value="1"/>
</DataField>
<DataField name="duration" optype="continuous" dataType="double"/>
<DataField name="use_number" optype="continuous" dataType="double"/>
<DataField name="duration_trend" optype="continuous" dataType="double"/>
<DataField name="duration_workday" optype="continuous" dataType="double"/>
<DataField name="use_number_workday" optype="continuous" dataType="double"/>
<DataField name="duration_weekend" optype="continuous" dataType="double"/>
<DataField name="use_number_weekend" optype="continuous" dataType="double"/>
</DataDictionary>
<TreeModel functionName="classification" missingValueStrategy="nullPrediction">
<MiningSchema>
<MiningField name="gender" usageType="target"/>
<MiningField name="duration"/>
<MiningField name="use_number"/>
<MiningField name="duration_trend"/>
<MiningField name="duration_workday"/>
<MiningField name="use_number_workday"/>
<MiningField name="duration_weekend"/>
<MiningField name="use_number_weekend"/>
</MiningSchema>
<Output>
<OutputField name="pmml(prediction)" optype="categorical" dataType="integer" isFinalResult="false"/>
<OutputField name="prediction" optype="continuous" dataType="double" feature="transformedValue">
<MapValues outputColumn="data:output" dataType="double">
<FieldColumnPair field="pmml(prediction)" column="data:input"/>
<InlineTable>
<row>
<data:input>0</data:input>
<data:output>0</data:output>
</row>
<row>
<data:input>1</data:input>
<data:output>1</data:output>
</row>
</InlineTable>
</MapValues>
</OutputField>
<OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
<OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
</Output>
<Node>
<True/>
<Node>
<SimplePredicate field="duration" operator="lessOrEqual" value="770.2519269776877"/>
<Node>
<SimplePredicate field="duration_weekend" operator="lessOrEqual" value="133.434375"/>
<Node>
<SimplePredicate field="duration_workday" operator="lessOrEqual" value="571.2881944444446"/>
<Node>
<SimplePredicate field="use_number_weekend" operator="lessOrEqual" value="467.5"/>
<Node score="0" recordCount="400">
<SimplePredicate field="use_number_workday" operator="lessOrEqual" value="486.6594202898551"/>
<ScoreDistribution value="0" recordCount="227.0"/>
<ScoreDistribution value="1" recordCount="173.0"/>
</Node>
<Node score="1" recordCount="8">
<True/>
<ScoreDistribution value="0" recordCount="2.0"/>
<ScoreDistribution value="1" recordCount="6.0"/>
</Node>
</Node>
<Node score="1" recordCount="4">
<True/>
<ScoreDistribution value="0" recordCount="0.0"/>
<ScoreDistribution value="1" recordCount="4.0"/>
</Node>
</Node>
<Node>
<SimplePredicate field="duration_trend" operator="lessOrEqual" value="-73.75824404761904"/>
<Node score="1" recordCount="1">
<SimplePredicate field="duration" operator="lessOrEqual" value="539.9891025641025"/>
<ScoreDistribution value="0" recordCount="0.0"/>
<ScoreDistribution value="1" recordCount="1.0"/>
</Node>
<Node score="0" recordCount="4">
<True/>
<ScoreDistribution value="0" recordCount="4.0"/>
<ScoreDistribution value="1" recordCount="0.0"/>
</Node>
</Node>
<Node score="0" recordCount="17">
<True/>
<ScoreDistribution value="0" recordCount="17.0"/>
<ScoreDistribution value="1" recordCount="0.0"/>
</Node>
</Node>
<Node>
<SimplePredicate field="duration_trend" operator="lessOrEqual" value="20.406185850439883"/>
<Node>
<SimplePredicate field="duration_trend" operator="lessOrEqual" value="-16.603133718344537"/>
<Node score="0" recordCount="162">
<SimplePredicate field="use_number" operator="lessOrEqual" value="326.46770833333335"/>
<ScoreDistribution value="0" recordCount="98.0"/>
<ScoreDistribution value="1" recordCount="64.0"/>
</Node>
<Node score="1" recordCount="96">
<True/>
<ScoreDistribution value="0" recordCount="41.0"/>
<ScoreDistribution value="1" recordCount="55.0"/>
</Node>
</Node>
<Node score="0" recordCount="134">
<SimplePredicate field="duration_workday" operator="lessOrEqual" value="149.45138888888889"/>
<ScoreDistribution value="0" recordCount="73.0"/>
<ScoreDistribution value="1" recordCount="61.0"/>
</Node>
<Node score="1" recordCount="2247">
<True/>
<ScoreDistribution value="0" recordCount="930.0"/>
<ScoreDistribution value="1" recordCount="1317.0"/>
</Node>
</Node>
<Node>
<SimplePredicate field="duration_weekend" operator="lessOrEqual" value="925.04375"/>
<Node score="0" recordCount="26">
<SimplePredicate field="duration" operator="lessOrEqual" value="302.228125"/>
<ScoreDistribution value="0" recordCount="19.0"/>
<ScoreDistribution value="1" recordCount="7.0"/>
</Node>
<Node score="0" recordCount="155">
<True/>
<ScoreDistribution value="0" recordCount="82.0"/>
<ScoreDistribution value="1" recordCount="73.0"/>
</Node>
</Node>
<!-- 此处省略其余Node节点(以下为树中最后一个叶子节点的结尾部分) -->
<True/>
<ScoreDistribution value="0" recordCount="12.0"/>
<ScoreDistribution value="1" recordCount="24.0"/>
</Node>
</Node>
</TreeModel>
</PMML>
过程2:在Java环境下加载PMML模型文件并进行预测
Code:
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
/**
 * Loads a PMML model file with the JPMML evaluator and scores a single record.
 *
 * <p>The model path may be passed as the first command-line argument; when it is omitted the
 * original hard-coded location is used.
 */
public class PMML2EvaluatorDemo {

    /**
     * Reads the PMML file at {@code filePath} and builds an evaluator for it.
     *
     * @param filePath path of the PMML file to load
     * @return the evaluator, or {@code null} when the file cannot be opened or parsed
     *         (the cause is printed to standard error)
     */
    private Evaluator loadPmml(String filePath){
        InputStream inputStream;
        try {
            inputStream = new FileInputStream(filePath);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }

        PMML pmml = null;
        try {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (SAXException e) {
            e.printStackTrace();
        } catch (JAXBException e) {
            e.printStackTrace();
        } finally {
            // Always close the input stream, whether or not parsing succeeded.
            try {
                inputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        // Do not build an evaluator from an empty document when parsing failed.
        if (pmml == null) {
            return null;
        }

        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    /**
     * Prints the names of the model's input fields and target fields.
     *
     * @param evaluator evaluator whose fields are listed
     */
    private void showEvaluator(Evaluator evaluator) {
        List<InputField> inputFields = evaluator.getInputFields();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            System.out.println("inputFieldName:" + inputFieldName);
        }
        List<TargetField> targetFields = evaluator.getTargetFields();
        for (TargetField targetField : targetFields) {
            FieldName targetFieldName = targetField.getName();
            System.out.println("targetFieldName:" + targetFieldName);
        }
    }

    /**
     * Scores one record with the given evaluator and prints the inputs and the prediction.
     *
     * @param evaluator evaluator of the loaded model
     * @param a value of {@code duration}
     * @param b value of {@code use_number}
     * @param c value of {@code duration_trend}
     * @param d value of {@code duration_workday}
     * @param e value of {@code use_number_workday}
     * @param f value of {@code duration_weekend}
     * @param g value of {@code use_number_weekend}
     * @return the predicted label of the first target field as an int, or {@code -1} when the
     *         model has no target field or the result is not numeric
     */
    private int predictAppused2gender(Evaluator evaluator,double a, double b, double c, double d, double e, double f, double g) {
        Map<String, Double> data = new HashMap<String, Double>();
        data.put("duration", a);
        data.put("use_number", b);
        data.put("duration_trend", c);
        data.put("duration_workday", d);
        data.put("use_number_workday", e);
        data.put("duration_weekend", f);
        data.put("use_number_weekend", g);

        // Walk the model's own input fields and look up each raw value by field name,
        // letting the field prepare (validate/convert) it for evaluation.
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        int primitiveValue = -1;
        List<TargetField> targetFields = evaluator.getTargetFields();
        if (!targetFields.isEmpty()) {
            TargetField targetField = targetFields.get(0);
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);

            // Complex results expose the plain prediction through Computable#getResult().
            Object resultValue = targetFieldValue;
            if (targetFieldValue instanceof Computable) {
                resultValue = ((Computable) targetFieldValue).getResult();
            }
            // Accept any numeric result type instead of assuming Integer with an unchecked cast.
            if (resultValue instanceof Number) {
                primitiveValue = ((Number) resultValue).intValue();
            }
        }

        shuchu(data);
        System.out.println("duration:" +a + " use_number:" + b + " duration_trend:" + c + " duration_workday:" + d +
                " use_number_workday:" + e +" duration_weekend:" + f + " use_number_weekend:" + g +" --> gender:" + primitiveValue);
        return primitiveValue;
    }

    /**
     * Prints every entry of {@code data} as {@code key:value}, one per line, in the map's
     * iteration order.
     *
     * @param data values to print
     */
    public void shuchu(Map<String, Double> data) {
        for (Map.Entry<String, Double> entry : data.entrySet()) {
            System.out.println(entry.getKey() + ":" + entry.getValue());
        }
    }

    /**
     * Program entry point.
     *
     * @param args optional: {@code args[0]} is the path of the PMML file to load
     */
    public static void main(String args[]){
        PMML2EvaluatorDemo demo = new PMML2EvaluatorDemo();
        // Load the model stored in the PMML file; the first argument overrides the default path.
        // NOTE(review): the default file name mentions "MultilayerPerceptronClassifier" — confirm it
        // is the model file intended for this demo.
        String modelPath = "C:\\Users\\11103378\\IdeaProjects\\Spark\\SparkModel2PMML\\Appused2genderMultilayerPerceptronClassifier.pmml";
        if (args != null && args.length > 0) {
            modelPath = args[0];
        }
        Evaluator model = demo.loadPmml(modelPath);
        if (model == null) {
            System.err.println("Failed to load PMML model from: " + modelPath);
            return;
        }
        demo.showEvaluator(model);
        // Score a single record.
        demo.predictAppused2gender(model,465.4,156.6,-0.9,443.9,145.0,529.9,191.5);
    }
    /*
    duration:465.4
    use_number_workday:145.0
    duration_trend:-0.9
    duration_weekend:529.9
    use_number_weekend:191.5
    use_number:156.6
    duration_workday:443.9
    duration:465.4 use_number:156.6 duration_trend:-0.9 duration_workday:443.9 use_number_workday:145.0 duration_weekend:529.9 use_number_weekend:191.5 --> gender:1
    */
}
配置pom文件如下:
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the Java project that loads the exported PMML model and scores records. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.xdb.spark</groupId>
    <artifactId>LoadPmmlDemo</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <java.version>1.8</java.version>
        <!-- java.version alone is not read by the compiler plugin; wire it to the properties it does read. -->
        <maven.compiler.source>${java.version}</maven.compiler.source>
        <maven.compiler.target>${java.version}</maven.compiler.target>
        <pmml-model.version>1.4.2</pmml-model.version>
        <!-- Single place to bump both pmml-evaluator artifacts together. -->
        <pmml-evaluator.version>1.4.2</pmml-evaluator.version>
    </properties>
    <dependencies>
        <!-- PMML class model (parsing/unmarshalling of PMML documents) -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-model</artifactId>
            <version>${pmml-model.version}</version>
        </dependency>
        <!-- PMML evaluator (scoring engine) -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>${pmml-evaluator.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>${pmml-evaluator.version}</version>
        </dependency>
    </dependencies>
</project>