spark在spring(Java)中的运用

spark在spring(Java)中的运用

转载

在Java Web中使用Spark MLlib训练的模型
作者:xingoo
出处:http://www.cnblogs.com/xing901022

Spark MLlib之决策树(DecisioinTree)
作者:caiandyong
出处:https://blog.csdn.net/caiandyong/article/details/51831097

想法

问题:在假期的最后两天我突然对正在做的项目有了一个想法,我正在做的是spring的项目,我能不能把上学期刚学的spark MLlib的机器模型用在里面呢?(加点亮点的同时也能运用新知识)

总体实现

纵观从开始到实现,都有一篇文章不得不提
在Java Web中使用Spark MLlib训练的模型.
从这里开始我想到实现的方法是

  1. 用sacla语言训练好模型
  2. 使用PMML将模型导出
  3. 在需要的java环节中直接使用导出的模型

导出模型

因为实现过scala在IDEA训练模型,所以我直接开始了解如何将训练好的模型使用PMML导出。
我在spark决策树的基础上准备把他导出
使用的这个例子:Spark MLlib之决策树(DecisioinTree).
结果就是一直报错,发现走不通
走了很多弯路,因为资料太少也无法借鉴
最后只能去去看官方API了,看着看着官网说我下载的spark就有例子。。。
就在:examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala
代码如下:(数据也在spark的data文件夹下)

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Load and parse the data
val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()

// Cluster the data into two classes using KMeans
val numClusters = 2
val numIterations = 20
val clusters = KMeans.train(parsedData, numClusters, numIterations)

// Export to PMML to a String in PMML format
println(s"PMML Model:\n ${clusters.toPMML}")

// Export the model to a local file in PMML format
clusters.toPMML("/tmp/kmeans.xml")

// Export the model to a directory on a distributed file system in PMML format
clusters.toPMML(sc, "/tmp/kmeans")

// Export the model to the OutputStream in PMML format
clusters.toPMML(System.out)

在java环境使用该模型

这时候就需要一开始的
在Java Web中使用Spark MLlib训练的模型
根据博主的使用方法就可以在java环境使用了

注意事项

导包很重要!
因为导包导的太多导致冲突报错就是过程中最大的问题
我在导出模型的时候依赖是:

<properties>
            <spark.version>2.1.0</spark.version>
            <scala.version>2.11</scala.version>
        </properties>
        <repositories>
            <repository>
                <id>nexus-aliyun</id>
                <name>Nexus aliyun</name>
                <url>http://maven.aliyun.com/nexus/content/groups/public</url>
            </repository>
        </repositories>
        <dependencies>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-core_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-streaming_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-hive_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-mllib_${scala.version}</artifactId>
                <version>${spark.version}</version>
                <exclusions>
                    <exclusion>
                        <groupId>org.jpmml</groupId>
                        <artifactId>pmml-model</artifactId>
                    </exclusion>
                </exclusions>
            </dependency>

            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql_2.11</artifactId>
                <version>2.1.1</version>
            </dependency>

            <dependency>
                <groupId>org.jpmml</groupId>
                <artifactId>jpmml-evaluator-spark</artifactId>
                <version>1.2.2</version>
            </dependency>

            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-mllib_2.11</artifactId>
                <version>${spark.version}</version>
                <!--<scope>runtime</scope>-->
            </dependency>

然后在使用模型的依赖是:

<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator -->
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator</artifactId>
        <version>1.4.3</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator-extension -->
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator-extension</artifactId>
        <version>1.4.3</version>
    </dependency>

(因为不知道哪里冲突所以就开了两个项目。。。一定要分开导入依赖)

输出结果

数据:2.txt

0,32 1 2 0
1,27 1 1 1
1,29 1 1 0
1,25 1 2 1
0,23 0 2 1

数据:3.txt

0,32 1 1 0
0,25 1 2 0
1,29 1 2 1
1,24 1 1 0
0,31 1 1 0
1,35 1 2 1
0,30 0 1 0
0,31 1 1 0
1,30 1 2 1
1,21 1 1 0
0,21 1 2 0
1,21 1 2 1
0,29 0 2 1
0,29 1 0 1
0,29 0 2 1
1,30 1 1 0

导出模型代码:

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint

/**
 * @Author Song
 * @Date 2021/3/4 9:13
 * @Version 1.0
 */
object demo4 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("demo3")
    val sc = new SparkContext(conf)
    // $example on$
    // Load and parse the data
    val data = sc.textFile("D:\\111TEST\\data\\3.txt")
    val parsedData = data.map(line => LabeledPoint.parse(line))//.cache()
    val testData=sc.textFile("D:\\111TEST\\data\\2.txt")
    val parsedData2 = testData.map(line => LabeledPoint.parse(line))
    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(2)
      .run(parsedData)
    val predictionAndLabels = parsedData2.map { case LabeledPoint(label, features) =>
      val prediction = model.predict(features)
      (prediction, label)
    }
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val accuracy = metrics.accuracy
    println(s"Accuracy = $accuracy")


    // Export the model to a local file in PMML format
    model.toPMML("D:\\111TEST\\data2\\simple.xml")

    // Export the model to a directory on a distributed file system in PMML format
    model.toPMML(sc, "D:\\111TEST\\data2\\simple")

    // Export the model to the OutputStream in PMML format
    model.toPMML(System.out)
    // $example off$
    sc.stop()
  }
}

导入模型代码:

/**
 * @Author Song
 * @Date 2021/3/4 9:25
 * @Version 1.0
 */
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;


class PMMLDemo2 {
    private Evaluator loadPmml(){
        PMML pmml1 = new PMML();
        try(InputStream inputStream = new FileInputStream("D:\\111TEST\\data2\\simple.xml")){
            pmml1 = org.jpmml.model.PMMLUtil.unmarshal(inputStream);

        } catch (Exception e) {
            e.printStackTrace();
        }

        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml1);

    }
    private Object predict(Evaluator evaluator,int a, int b, int c, int d) {
        Map<String, Integer> data = new HashMap<String, Integer>();
        data.put("field_0", a);
        data.put("field_1", b);
        data.put("field_2", c);
        data.put("field_3", d);
        List<InputField> inputFields = evaluator.getInputFields();
        //过模型的原始特征,从画像中获取数据,作为模型输入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();
        ProbabilityDistribution target = (ProbabilityDistribution) results.get(targetFieldName);
        System.out.println(a + " " + b + " " + c + " " + d + ":" + target);
        return target;
    }
    public static void main(String args[]){
        PMMLDemo2 demo = new PMMLDemo2();
        Evaluator model = demo.loadPmml();
        demo.predict(model,27,1,1,1);
        demo.predict(model,25,1,2,1);
        demo.predict(model,23,0,2,1);
    }
}
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值