Spark-ml模型保存为PMML

spark版本2.1.3

maven设置

    <!-- Spark MLlib, with its transitive pmml-model dependency excluded. -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.11</artifactId>
        <version>2.1.3</version>
        <exclusions>
            <exclusion>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-model</artifactId>
            </exclusion>
        </exclusions>
    </dependency>
    <!-- JPMML-SparkML converter used to build the PMML document. -->
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>jpmml-sparkml</artifactId>
        <version>1.2.13</version>
    </dependency>

注意:spark-mllib 自带了旧版本的 pmml-model 依赖,必须像上面那样排除掉,否则会与 jpmml-sparkml 所依赖的 pmml-model 版本冲突。


模型训练

import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;

import org.dmg.pmml.PMML;
import org.jpmml.model.JAXBUtil;
import org.jpmml.sparkml.PMMLBuilder;
。。。

        // Input columns, in the order they are packed into the "features" vector.
        String[] features = {
                "category", "future_day", "banner_min_time","banner_min_price",
                "page_train", "page_flight", "page_bus", "page_transfer",
                "start_end_distance", "total_transport", "high_railway_percent", "avg_time", "min_time",
                "avg_price", "min_price",
                "label_05060801", "label_05060701", "label_05060601", "label_02050601", "label_02050501", "label_02050401",
                "is_match_category", "train_consumer_prefer", "flight_consumer_prefer", "bus_consumer_prefer"
        };
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features");

        // Random forest reading the assembled vector; "isclick" is the label column.
        // The seed is fixed so repeated runs use the same random seed.
        RandomForestClassifier rf = new RandomForestClassifier()
                .setLabelCol("isclick")
                .setFeaturesCol("features")
                .setMaxDepth(7)
                .setNumTrees(60)
                .setSeed(2018)
                .setMinInstancesPerNode(1);

        // Chain assembling and classification into one pipeline and fit it on the training data.
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{assembler, rf});
        PipelineModel pipelineModel = pipeline.fit(trainData);

保存PipelineModel模型

 // Persist the fitted pipeline with Spark's own ML writer; overwrite() replaces anything already at the path.
 pipelineModel.write().overwrite().save("D://model/random-forest");

保存到 HDFS 的方法相同,只需把保存路径换成 HDFS 路径即可。


将模型保存为PMML格式的文件

        // The PMML builder takes the schema of the training data together with the fitted pipeline.
        StructType schema = trainData.schema();
        PMML pmml = new PMMLBuilder(schema, pipelineModel).build();
        // Write the PMML document to the local file system; the HDFS variant below is disabled.
        saveToLocalFile(pmml);
//        saveToHdfsFile(pmml);
    /**
     * Marshals the given PMML document to a fixed local file.
     *
     * @param pmml the PMML document to write
     * @throws IllegalStateException if the target file cannot be opened or written, or if
     *         marshalling fails; the underlying exception is attached as the cause
     */
    private void saveToLocalFile(PMML pmml) {
        String targetFile = "D://model/pmml/pipemodel";
        try (FileOutputStream out = new FileOutputStream(targetFile)) {
            JAXBUtil.marshalPMML(pmml, new StreamResult(out));
        } catch (JAXBException | IOException e) {
            // IOException also covers FileNotFoundException (target cannot be opened).
            // Fail loudly instead of printing and returning: a swallowed error here would
            // let the caller continue as if the model file had been written.
            throw new IllegalStateException("Failed to save PMML to " + targetFile, e);
        }
    }

将PMML保存到HDFS路径,其中 fileSystem 为 HDFS 文件系统(FileSystem)对象

    /**
     * Marshals the given PMML document to a fixed path on the file system held in
     * {@code fileSystem}.
     *
     * @param pmml the PMML document to write
     * @throws IOException if the target file cannot be created or written, or if marshalling
     *         the PMML fails (the {@code JAXBException} is attached as the cause)
     */
    private void saveToHdfsFile(PMML pmml) throws IOException {
        String targetFile = "/data/twms/traffichuixing/model/stage/pmml/rf-pipepmml";
        Path path = new Path(targetFile);
        try (FSDataOutputStream fos = fileSystem.create(path)) {
            JAXBUtil.marshalPMML(pmml, new StreamResult(fos));
        } catch (JAXBException e) {
            // Report marshalling failures through the already-declared IOException so callers
            // see every failed export the same way, with the full cause preserved; swallowing
            // it would leave a possibly incomplete file behind while the caller assumes success.
            throw new IOException("Failed to marshal PMML to " + targetFile, e);
        }
    }
评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值