package com.xtf.demo.mllib
import com.hankcs.hanlp.dictionary.stopword.CoreStopWordDictionary
import com.hankcs.hanlp.tokenizer.StandardTokenizer
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.ml.linalg.{DenseVector, Vectors}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{Column, SparkSession}
import breeze.linalg.Vector
import org.apache.spark.ml.clustering.KMeans
import scala.collection.JavaConversions._
import org.apache.spark.sql.functions._
import scala.collection.mutable
/**
 * Topic analysis: clusters articles by their title and body text.
 *
 *  1. Chinese word segmentation with HanLP (stop words removed);
 *  2. Word2Vec training over the segmented documents;
 *  3. KMeans training over one vector per document;
 *  4. KMeans prediction (cluster assignment) for every document.
 */
object KmeansCluster {
  Logger.getLogger("org").setLevel(Level.ERROR)

  /** Dimensionality of the Word2Vec embeddings, and therefore of every document vector. */
  private val VectorSize = 10

  /** Input location used when no path is passed on the command line. */
  private val DefaultInputPath = "./dd1.txt"

  /**
   * Spark UDF mapping a (title, text) pair to the words of both fields, title words first.
   *
   * Each field is segmented on its own so the last word of the title cannot be fused with the
   * first word of the body into a bogus token, and a `null` field contributes no words
   * (plain `str1 + str2` would turn a null into the literal text "null").
   */
  val segmentUDF: UserDefinedFunction =
    udf((str1: String, str2: String) => textSegment(str1) ++ textSegment(str2))

  /**
   * Segments `str` with HanLP's standard tokenizer and removes stop words.
   *
   * @param str text to segment; `null` or blank input yields an empty array
   * @return the remaining words, in their original order
   */
  def textSegment(str: String): Array[String] = {
    import scala.collection.JavaConverters._
    Option(str).filter(_.trim.nonEmpty).fold(Array.empty[String]) { text =>
      val terms = StandardTokenizer.segment(text)
      // Filters `terms` in place; the filtered list is read on the next line.
      CoreStopWordDictionary.apply(terms)
      terms.asScala.map(_.word).toArray
    }
  }

  /** One row of the document-vector stage: a document id and its embedding. */
  case class VectorGen(docid: String, vec: DenseVector)

  /**
   * Runs the segmentation -> Word2Vec -> KMeans pipeline on a local Spark session.
   *
   * @param args optional; `args(0)` is the JSON input path (defaults to `./dd1.txt`).
   *             Each record must provide `docid`, `title` and `text` fields.
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("test")
      .master("local[*]")
      .getOrCreate()
    // Release the session even when a stage fails.
    try run(spark, args.headOption.getOrElse(DefaultInputPath))
    finally spark.stop()
  }

  /** Executes the pipeline against `inputPath`, printing every intermediate result. */
  private def run(spark: SparkSession, inputPath: String): Unit = {
    import spark.implicits._

    val dataDF = spark.read.json(inputPath).select("docid", "title", "text")
    dataDF.show(false)

    // Segment with a UDF; a Dataset.map over the rows calling textSegment would work too.
    val udfDF = dataDF.withColumn("words", segmentUDF(col("title"), col("text")))
    udfDF.show(false)

    // minCount = 0 keeps every token in the vocabulary, so each document word gets an embedding.
    val word2VecModel = new Word2Vec()
      .setInputCol("words")
      .setOutputCol("result")
      .setVectorSize(VectorSize)
      .setMinCount(0)
      .fit(udfDF.select("words"))

    // Vocabulary: one row per word with its learned vector.
    val vectors = word2VecModel.getVectors
    vectors.show(false)
    val vectorsMap: Map[String, Array[Double]] = vectors.collect().map { row =>
      row.getAs[String]("word") -> row.getAs[DenseVector]("vector").toArray
    }.toMap
    // Render the arrays explicitly; printing Array[Double] directly only shows hash codes.
    vectorsMap.take(10).foreach { case (word, vec) =>
      println(s"$word -> ${vec.mkString("[", ", ", "]")}")
    }

    // Ship the vocabulary to executors once rather than inside every task closure.
    val vocab = spark.sparkContext.broadcast(vectorsMap)
    val dim = VectorSize
    val vectorsDF = udfDF.rdd.map { row =>
      val docVec = sumWordVectors(row.getAs[Seq[String]]("words"), vocab.value, dim)
      VectorGen(row.getAs[String]("docid"), Vectors.dense(docVec).toDense)
    }.toDF()
    vectorsDF.show(false)

    val kMeansModel = new KMeans()
      .setK(5)
      .setFeaturesCol("vec")
      .setMaxIter(50)
      .setSeed(1L)
      .fit(vectorsDF)
    kMeansModel.transform(vectorsDF).show(false)

    // Within-set sum of squared errors. Lower means tighter clusters, but it tends to shrink as
    // K grows, so compare several K values (same iteration count) rather than taking the minimum.
    // computeCost is deprecated from Spark 2.4 (ClusteringEvaluator replaces it); fine on 2.3.x.
    println(kMeansModel.computeCost(vectorsDF))
  }

  /**
   * Sums the embeddings of `words` into one vector of length `dim`.
   * Words missing from `vocab` are skipped (equivalent to adding zeros), and an empty word list
   * yields the all-zero vector instead of failing as `reduce` would.
   */
  private def sumWordVectors(words: Seq[String], vocab: Map[String, Array[Double]], dim: Int): Array[Double] = {
    val zero: Vector[Double] = Vector(Array.fill(dim)(0.0))
    words.flatMap(vocab.get).foldLeft(zero)((acc, wordVec) => acc + Vector(wordVec)).toArray
  }
}
控制台打印:(原文未附运行输出)
POM 依赖:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>newsfeed</groupId>
    <artifactId>newsfeed</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Must stay on a 2.11.x release to match the _2.11 suffix of the Spark artifacts below.
             NOTE(review): Spark 2.3.x is built against Scala 2.11.8; consider a newer 2.11 patch release. -->
        <scala.version>2.11.0</scala.version>
        <spark.version>2.3.2</spark.version>
    </properties>

    <dependencies>
        <!-- HanLP: Chinese word segmentation and stop-word dictionary. -->
        <dependency>
            <groupId>com.hankcs</groupId>
            <artifactId>hanlp</artifactId>
            <version>portable-1.7.8</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
            <!--<scope>provided</scope>-->
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.4</version>
            <scope>test</scope>
        </dependency>
        <!-- NOTE(review): org.specs 1.2.5 dates from the Scala 2.7 era and is unlikely to work
             with Scala 2.11; confirm it is still needed. -->
        <dependency>
            <groupId>org.specs</groupId>
            <artifactId>specs</artifactId>
            <version>1.2.5</version>
            <scope>test</scope>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.hbase/hbase -->
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase</artifactId>
            <version>1.0.2</version>
            <type>pom</type>
        </dependency>
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-client</artifactId>
            <version>1.0.2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>1.0.2</version>
        </dependency>
    </dependencies>

    <build>
        <sourceDirectory>src/main/scala</sourceDirectory>
        <plugins>
            <plugin>
                <groupId>org.scala-tools</groupId>
                <artifactId>maven-scala-plugin</artifactId>
                <!-- Pinned for reproducible builds (2.15.2 is this plugin's final release). -->
                <version>2.15.2</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <scalaVersion>${scala.version}</scalaVersion>
                    <args>
                        <!-- jvm-1.5 is deprecated in Scala 2.11; jvm-1.6 is the 2.11 default and its
                             bytecode runs on the Java 8 runtime that Spark 2.3 requires. -->
                        <arg>-target:jvm-1.6</arg>
                    </args>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <version>2.13</version>
                <configuration>
                    <useFile>false</useFile>
                    <disableXmlReport>true</disableXmlReport>
                    <!-- If you have classpath issue like NoDefClassError,... -->
                    <!-- useManifestOnlyJar>false</useManifestOnlyJar -->
                    <includes>
                        <include>**/*Test.*</include>
                        <include>**/*Suite.*</include>
                    </includes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
关注微信公众号【飞哥大数据】，回复“666”即可获取 2022 年 100+ 家公司面试真题，以及 Spark 与 Flink 面试题汇总。