Spark 版本：2.1.3
Maven 依赖配置：
<!-- Spark MLlib 2.1.3 built for Scala 2.11. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.1.3</version>
<!-- Drop the pmml-model artifact that spark-mllib pulls in transitively, -->
<!-- so that only the pmml-model version required by jpmml-sparkml is on the classpath. -->
<exclusions>
<exclusion>
<groupId>org.jpmml</groupId>
<artifactId>pmml-model</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- JPMML-SparkML: converts fitted Spark ML pipelines to PMML. -->
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-sparkml</artifactId>
<version>1.2.13</version>
</dependency>
注意：spark-mllib 需要排除其自带的 pmml-model 依赖，否则会与 jpmml-sparkml 所依赖的 pmml-model 版本冲突。
模型训练
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.model.JAXBUtil;
import org.jpmml.sparkml.PMMLBuilder;
。。。
String[] features=new String[]{
"category", "future_day", "banner_min_time","banner_min_price",
"page_train", "page_flight", "page_bus", "page_transfer",
"start_end_distance", "total_transport", "high_railway_percent", "avg_time", "min_time",
"avg_price", "min_price",
"label_05060801", "label_05060701", "label_05060601", "label_02050601", "label_02050501", "label_02050401",
"is_match_category", "train_consumer_prefer", "flight_consumer_prefer", "bus_consumer_prefer"
};
VectorAssembler assembler = new VectorAssembler().setInputCols(features).setOutputCol("features");
RandomForestClassifier rf = new RandomForestClassifier()
.setLabelCol("isclick")
.setFeaturesCol("features")
.setMaxDepth(7)
.setNumTrees(60)
.setSeed(2018)
.setMinInstancesPerNode(1);
;
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{assembler, rf});
PipelineModel pipelineModel = pipeline.fit(trainData);
保存PipelineModel模型
// Persist the fitted pipeline in Spark ML's own format; overwrite() replaces
// any model that was previously saved at this path.
pipelineModel.write().overwrite().save("D://model/random-forest");
保存到 HDFS 的方法相同，只需把路径换成 HDFS 路径即可。
将模型保存为 PMML 格式的文件
// Convert the fitted pipeline to a PMML document. The builder is given the
// schema of the training data together with the fitted pipeline.
StructType schema = trainData.schema();
PMMLBuilder pmmlBuilder = new PMMLBuilder(schema, pipelineModel);
PMML pmml = pmmlBuilder.build();

// Write the document to the local file system.
saveToLocalFile(pmml);
// To write to HDFS instead: saveToHdfsFile(pmml);
/**
 * Serializes the given PMML document to a fixed file on the local file system.
 *
 * <p>Any failure is rethrown as an unchecked exception that keeps the original
 * exception as its cause. Printing the stack trace and returning normally would
 * let the caller believe the export succeeded when no usable file was written.
 *
 * @param pmml the PMML document to write
 * @throws IllegalStateException if the target file cannot be opened or written,
 *         or if JAXB marshalling of the document fails
 */
private void saveToLocalFile(PMML pmml) {
    String targetFile = "D://model/pmml/pipemodel";
    // try-with-resources closes the stream even when marshalling fails.
    try (FileOutputStream out = new FileOutputStream(targetFile)) {
        JAXBUtil.marshalPMML(pmml, new StreamResult(out));
    } catch (JAXBException | IOException e) {
        // FileNotFoundException is an IOException, so it is covered here too.
        throw new IllegalStateException("Failed to write PMML to " + targetFile, e);
    }
}
将 PMML 保存到 HDFS 路径（其中 fileSystem 为 HDFS 文件系统对象）
/**
 * Serializes the given PMML document to a fixed path through {@code fileSystem}.
 *
 * <p>{@code fileSystem} is defined outside this method and is expected to be
 * the HDFS {@code FileSystem} instance.
 *
 * <p>A JAXB marshalling failure is wrapped in an {@link IOException} with the
 * original exception as its cause, so every failure reaches the caller through
 * the exception type this method already declares, instead of being swallowed.
 *
 * @param pmml the PMML document to write
 * @throws IOException if the file cannot be created or written, or if JAXB
 *         marshalling of the document fails
 */
private void saveToHdfsFile(PMML pmml) throws IOException {
    String targetFile = "/data/twms/traffichuixing/model/stage/pmml/rf-pipepmml";
    Path path = new Path(targetFile);
    // try-with-resources closes the stream even when marshalling fails.
    try (FSDataOutputStream out = fileSystem.create(path)) {
        JAXBUtil.marshalPMML(pmml, new StreamResult(out));
    } catch (JAXBException e) {
        throw new IOException("Failed to marshal PMML to " + targetFile, e);
    }
}