// Gaussian Naive Bayes implemented with Spark (Spark实现高斯朴素贝叶斯)
import breeze.stats.distributions.Gaussian
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.stat.Summarizer.{
mean => summaryMean,
variance => summaryVar
}
import org.apache.spark.sql.functions.udf
/**
 * Gaussian Naive Bayes on the iris data set, implemented directly with
 * Spark DataFrame aggregations (no MLlib NaiveBayes estimator).
 *
 * Steps:
 *   1. load the CSV and compute class priors P(class) via rollup counts;
 *   2. assemble feature columns into a vector and compute per-class
 *      mean/variance vectors with Summarizer;
 *   3. score each row with argmax over log P(class) + sum_i log N(x_i; mu, sigma^2)
 *      and print the misclassified rows.
 *
 * @param args unused command-line arguments
 */
def main(args: Array[String]): Unit = {
  val spark = SparkSession
    .builder()
    .appName(s"${this.getClass.getSimpleName}")
    .master("local[*]")
    .getOrCreate()
  import spark.implicits._

  // Load the data set (header row, inferred column types).
  val irisData = spark.read
    .option("header", true)
    .option("inferSchema", true)
    .csv("F:\\DataSource\\iris.csv")

  // rollup("class") yields one count row per class plus a grand-total
  // row whose "class" column is NULL.
  val rolluped = irisData.rollup($"class").count()

  // Total sample count comes from the NULL grand-total row.
  val sampleSize = rolluped.where($"class".isNull).head().getAs[Long](1)

  // Prior probabilities: P(class) = count(class) / N.
  val pprobMap = rolluped
    .where($"class".isNotNull)
    .withColumn("pprob", $"count" / sampleSize)
    .collect()
    .map(row => row.getAs[String]("class") -> row.getAs[Double]("pprob"))
    .toMap

  // Every column except the label is a numeric feature.
  val schema = irisData.schema
  val fts = schema.filterNot(_.name == "class").map(_.name).toArray

  // Assemble the feature columns into a single vector column.
  val amountVectorAssembler: VectorAssembler = new VectorAssembler()
    .setInputCols(fts)
    .setOutputCol("features")
  val ftsDF = amountVectorAssembler
    .transform(irisData)
    .select("class", "features")

  // Per-class mean and variance vectors of the features.
  val irisAggred = ftsDF
    .groupBy($"class")
    .agg(
      summaryMean($"features") as "mfts",
      summaryVar($"features") as "vfts"
    )

  // (mean, variance) pairs per feature, keyed by class label.
  val cprobs: Array[(Array[(Double, Double)], String)] = irisAggred
    .collect()
    .map { row =>
      val cl = row.getAs[String]("class")
      val mus = row.getAs[DenseVector]("mfts").toArray
      val vars = row.getAs[DenseVector]("vfts").toArray
      (mus.zip(vars), cl)
    }

  // Tiny additive smoothing keeps the Gaussian well-defined when a
  // feature has zero variance within a class (same idea as sklearn's
  // GaussianNB var_smoothing).
  val varSmoothing = 1e-9

  // Log-density of N(mu, sigma2) at x. Working in log space avoids the
  // floating-point underflow a raw product of densities would suffer.
  def logPdf(x: Double, mu: Double, sigma2: Double): Double =
    Gaussian(mu, math.sqrt(sigma2 + varSmoothing)).logPdf(x)

  // Predict: argmax over classes of log P(class) + sum_i log p(x_i | class).
  val predictUDF = udf((vec: DenseVector) => {
    cprobs
      .map { case (stats, cl) =>
        val logLikelihood = stats
          .zip(vec.toArray)
          .map { case ((mu, sigma2), x) => logPdf(x, mu, sigma2) }
          .sum
        // An unseen class gets prior 0 -> log 0 = -Infinity, so it can
        // never win the argmax.
        val logPrior = math.log(pprobMap.getOrElse(cl, 0.0))
        (logLikelihood + logPrior, cl)
      }
      .maxBy(_._1)
      ._2
  })

  // Report only the misclassified rows.
  val predictDF = ftsDF
    .withColumn("predict", predictUDF($"features"))
  predictDF.where($"class" =!= $"predict").show(truncate = false)

  spark.stop()
}
// Sample output (misclassified rows):
// +---------------+-----------------+---------------+
// |class          |features         |predict        |
// +---------------+-----------------+---------------+
// |Iris-versicolor|[6.9,3.1,4.9,1.5]|Iris-virginica |
// |Iris-versicolor|[5.9,3.2,4.8,1.8]|Iris-virginica |
// |Iris-versicolor|[6.7,3.0,5.0,1.7]|Iris-virginica |
// |Iris-virginica |[4.9,2.5,4.5,1.7]|Iris-versicolor|
// |Iris-virginica |[6.0,2.2,5.0,1.5]|Iris-versicolor|
// |Iris-virginica |[6.3,2.8,5.1,1.5]|Iris-versicolor|
// +---------------+-----------------+---------------+