package com.xtf.demo.mllib
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, IndexToString, StringIndexer, StringIndexerModel}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.ml.linalg.Vector
/**
 * End-to-end Spark ML pipeline demo: StringIndexer -> HashingTF -> LogisticRegression.
 * Trains on a tiny hand-built corpus, saves the fitted model, scores a test set,
 * and maps numeric predictions back to the original channel strings.
 */
object PipelinesClassification {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("test").master("local[*]").getOrCreate()

    // Training set: (id, tokenized words, channel label as a string).
    val training = spark.createDataFrame(Seq(
      ("aaa", Array("a", "b", "c", "d", "e", "spark"), "6"),
      ("bbb", Array("b", "d"), "7"),
      ("ccc", Array("spark", "f", "g", "h"), "6"),
      ("ddd", Array("hadoop", "mapreduce"), "7")
    )).toDF("id", "words", "channel")

    // Stage 1: encode the string channel into the numeric "label" column.
    val indexer = new StringIndexer()
      .setInputCol("channel")
      .setOutputCol("label")

    // Stage 2: hash each bag of words into a 1000-dimensional feature vector.
    val tf = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol("words")
      .setOutputCol("features")

    // Stage 3: logistic-regression classifier.
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.001)

    // Assemble the three stages and fit them as a single estimator.
    val pipeline = new Pipeline().setStages(Array(indexer, tf, lr))
    val model = pipeline.fit(training)

    // Persist the fitted PipelineModel, overwriting any previous run.
    model.write.overwrite().save("./spark-logistic-regression-model")

    // Unlabeled documents to classify.
    val test = spark.createDataFrame(Seq(
      ("eee", Array("spark", "i", "j", "k")),
      ("ggg", Array("l", "m", "n")),
      ("fff", Array("spark", "hadoop", "spark")),
      ("hhh", Array("apache", "hadoop"))
    )).toDF("id", "words")

    // Score the test set; adds features/rawPrediction/probability/prediction columns.
    val predicted = model.transform(test)
    predicted.show()

    // Convert numeric predictions back to channel strings, reusing the label
    // vocabulary learned by the StringIndexer (stage 0 of the fitted pipeline).
    val converter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictionIndex")
      .setLabels(model.stages(0).asInstanceOf[StringIndexerModel].labels)
    val labelled = converter.transform(predicted)
    labelled.show()

    labelled.select("id", "probability", "predictionIndex").collect().foreach {
      case Row(id: String, probability: Vector, predictionIndex: String) =>
        // The probability vector has one entry per class; the position of the
        // largest probability is the predicted class index.
        val prediction = probability.argmax
        println(probability)
        println(prediction)
        println(predictionIndex)
    }

    spark.stop()
  }
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>newsfeed</groupId>
<artifactId>newsfeed</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<scala.version>2.11.8</scala.version>
<spark.version>2.1.0</spark.version>
</properties>
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>${spark.version}</version>
<!--<scope>provided</scope>-->
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.11</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_2.11</artifactId>
<version>${spark.version}</version>
</dependency>
</dependencies>
<build>
<sourceDirectory>src/main/scala</sourceDirectory>
<plugins>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
<args>
<arg>-target:jvm-1.8</arg>
</args>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.13</version>
<configuration>
<useFile>false</useFile>
<disableXmlReport>true</disableXmlReport>
<!-- If you have classpath issue like NoDefClassError,... -->
<!-- useManifestOnlyJar>false</useManifestOnlyJar -->
<includes>
<include>**/*Test.*</include>
<include>**/*Suite.*</include>
</includes>
</configuration>
</plugin>
</plugins>
</build>
</project>
控制台打印结果
+---+--------------------+--------------------+--------------------+--------------------+----------+
| id| words| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+--------------------+----------+
|eee| [spark, i, j, k]|(1000,[105,149,32...|[1.66090332274733...|[0.84035922612126...| 0.0|
|ggg| [l, m, n]|(1000,[6,638,655]...|[-1.6421889526563...|[0.16216743145233...| 1.0|
|fff|[spark, hadoop, s...|(1000,[105,181],[...|[2.59801421743934...|[0.93073366867024...| 0.0|
|hhh| [apache, hadoop]|(1000,[181,495],[...|[-4.0081703333680...|[0.01784246665557...| 1.0|
+---+--------------------+--------------------+--------------------+--------------------+----------+
+---+--------------------+--------------------+--------------------+--------------------+----------+---------------+
| id| words| features| rawPrediction| probability|prediction|predictionIndex|
+---+--------------------+--------------------+--------------------+--------------------+----------+---------------+
|eee| [spark, i, j, k]|(1000,[105,149,32...|[1.66090332274733...|[0.84035922612126...| 0.0| 6|
|ggg| [l, m, n]|(1000,[6,638,655]...|[-1.6421889526563...|[0.16216743145233...| 1.0| 7|
|fff|[spark, hadoop, s...|(1000,[105,181],[...|[2.59801421743934...|[0.93073366867024...| 0.0| 6|
|hhh| [apache, hadoop]|(1000,[181,495],[...|[-4.0081703333680...|[0.01784246665557...| 1.0| 7|
+---+--------------------+--------------------+--------------------+--------------------+----------+---------------+
[0.8403592261212601,0.15964077387873987]
0
6
[0.16216743145233842,0.8378325685476616]
1
7
[0.9307336686702409,0.06926633132975912]
0
6
[0.017842466655579713,0.9821575333444204]
1
7
关注微信公众号【飞哥大数据】,回复666 获取2022年100+公司面试真题,以及spark与flink面试题汇总