Clustering data source download: http://download.csdn.net/detail/wguangliang/9595795
Local (single-machine) test code is provided below:
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.mllib.regression.LabeledPoint
import java.io.FileWriter
object GaussianMixtureTest {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    // if (args.length < 2) {
    //   println("usage: DenseGmmEM <input file> <k> [maxIterations]")
    // } else {
    //   val maxIterations = if (args.length > 2) args(2).toInt else 100
    //   run(args(0), args(1).toInt, maxIterations)
    // }
    // var max = 0.0
    // var maxIter = 0
    // for (i <- List(1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100)) {
    //   val corr = run("C:\\Users\\qingjian\\Desktop\\sf.txt", 2, i)
    //   if (corr > max) {
    //     max = corr
    //     maxIter = i
    //   }
    // }
    // println(maxIter + ":" + max)
    // run("C:\\Users\\qingjian\\Desktop\\sf.txt", 2, 3)
    run("C:\\Users\\qingjian\\Desktop\\Result_data.txt", 2, 4)
  }
  /**
   * inputFile: input file path; k: number of clusters (default 2); maxIterations: maximum number of EM iterations (default 100)
   */
  private def run(inputFile: String, k: Int, maxIterations: Int): Double = {
    // val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
    val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example").setMaster("local[4]")
    val ctx = new SparkContext(conf)
    // Feature vectors only: every column except the last (the last column is the class label).
    val data = ctx.textFile(inputFile).map { line =>
      val split = line.trim.split('\t').map(_.toDouble)
      Vectors.dense(split.init)
    }.cache()
    // The same file read again, this time keeping the label so the clustering can be evaluated.
    val dataWithLabel = ctx.textFile(inputFile).map { line =>
      val split = line.trim.split('\t').map(_.toDouble)
      LabeledPoint(split.last, Vectors.dense(split.init))
    }.cache()
    // Fit a Gaussian mixture model with k components via EM.
    val clusters = new GaussianMixture()
      .setK(k)
      .setMaxIterations(maxIterations)
      .run(data)
    /* Show each component's parameters and the membership probability of each point
    for (i <- 0 until clusters.k) {
      println("weight=%f\nmu=%s\nsigma=\n%s\n" format
        (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
    }
    println("The membership value of each vector to all mixture components (first <= 100):")
    val membership = clusters.predictSoft(data)
    membership.foreach { x =>
      println(" " + x.mkString(","))
    } */
    // Hard cluster assignment for every point, paired with its original LabeledPoint.
    val prediction = clusters.predict(dataWithLabel.map(_.features))
    val predictionWithLabel = dataWithLabel.zip(prediction)
    // Print each point together with its predicted cluster.
    predictionWithLabel.collect.map { x =>
      println("prediction " + x._1 + ": " + x._2)
    }
    val writer = new FileWriter("C:\\Users\\qingjian\\Desktop\\clustered_result.txt", false)
    // Class distribution of the original (labelled) data.
    val sourceData = predictionWithLabel.map(x => (x._1.label, 1)).reduceByKey(_ + _)
    writer.append("Original data\n")
    sourceData.collect.foreach(x => writer.append(x._1 + "\t" + x._2 + "(" + x._2 / predictionWithLabel.count.toDouble + ")\n"))
    // Label of the largest cluster, taken here as the cluster of normal (non-cheating) data.
    val data1TopClusterLabel = predictionWithLabel.map(x => (x._2, 1)).reduceByKey(_ + _).sortBy(_._2, false).first
    writer.append("\nClustered Instances [cluster " + data1TopClusterLabel._1 + " represents the normal data]\n")
    val clusteredInstance = predictionWithLabel.map(x => (x._2, 1)).reduceByKey(_ + _)
    clusteredInstance.collect.foreach(x => writer.append(x._1 + "\t" + x._2 + "(" + x._2 / predictionWithLabel.count.toDouble + ")\n"))
    writer.append("\n")
    clusteredInstance.collect.foreach(x => writer.append(x._1 + "\t"))
    writer.append("<--assigned to cluster\n")
    val correctPrediction = predictionWithLabel.filter(x => x._1.label == x._2)
    val errorPrediction = predictionWithLabel.filter(x => x._1.label != x._2)
    // Confusion matrix: row i = true label i, column j = predicted cluster j.
    for (i <- 0 until k) {
      for (j <- 0 until k) {
        if (i == j) {
          writer.append(correctPrediction.filter(_._1.label.toInt == i).count + "\t")
        } else {
          writer.append(errorPrediction.filter(x => x._1.label.toInt == i && x._2 == j).count + "\t")
        }
      }
      writer.append("" + i)
      writer.append("\n")
    }
    val err = errorPrediction.count / predictionWithLabel.count.toDouble
    // GMM cluster indices are arbitrary; if the largest cluster is not index 0, the indices are
    // assumed to be swapped relative to the class labels and the reported rate is flipped.
    if (data1TopClusterLabel._1.toInt != 0) {
      writer.append("Incorrectly clustered instances : " + errorPrediction.count + "\t" + (1 - err))
    } else {
      writer.append("Incorrectly clustered instances : " + errorPrediction.count + "\t" + err)
    }
    writer.close()
    ctx.stop()
    1 - err // return the accuracy
  }
}
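As an aside, the confusion matrix and error rate that the code above builds by hand can also be obtained from Spark's built-in MulticlassMetrics. The following is only a minimal sketch (not part of the original program), reusing the predictionWithLabel RDD constructed in run():

import org.apache.spark.mllib.evaluation.MulticlassMetrics

// MulticlassMetrics expects (prediction, label) pairs as Doubles.
// predictionWithLabel is the RDD[(LabeledPoint, Int)] built in run() above.
val predictionAndLabels = predictionWithLabel.map { case (point, cluster) =>
  (cluster.toDouble, point.label)
}
val metrics = new MulticlassMetrics(predictionAndLabels)

// Rows are true labels, columns are predicted clusters.
println(metrics.confusionMatrix)

// Accuracy computed directly; this is only meaningful if the GMM cluster indices
// happen to line up with the class labels, which is what the original code assumes.
val accuracy = predictionAndLabels.filter { case (p, l) => p == l }.count.toDouble /
  predictionAndLabels.count
println("accuracy = " + accuracy)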
Running the program writes the following summary to clustered_result.txt:

Original data
0.0	514(0.514)
1.0	486(0.486)

Clustered Instances [cluster 0 represents the normal data]
0	594(0.594)
1	406(0.406)

0	1	<--assigned to cluster
514	0	0
80	406	1
Incorrectly clustered instances : 80	0.08

That is, 80 of the 1000 points (all with true label 1) were assigned to cluster 0, giving an error rate of 80 / 1000 = 0.08, i.e. an accuracy of 0.92.
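For reference, on Spark 2.x the same clustering can be expressed with the DataFrame-based spark.ml API. The sketch below is only an assumption-laden illustration: it assumes a SparkSession named spark, the same tab-separated input format (feature columns followed by the label in the last column), and uses the default column names of the ml GaussianMixture estimator.

import org.apache.spark.ml.clustering.GaussianMixture
import org.apache.spark.ml.linalg.Vectors

import spark.implicits._

// Parse each tab-separated line into (label, features); the last column is the label.
val df = spark.read.textFile("C:\\Users\\qingjian\\Desktop\\Result_data.txt")
  .map { line =>
    val split = line.trim.split('\t').map(_.toDouble)
    (split.last, Vectors.dense(split.init))
  }
  .toDF("label", "features")

val gmm = new GaussianMixture()
  .setK(2)
  .setMaxIter(100)

val model = gmm.fit(df)

// transform() appends a "prediction" column (hard assignment) and a
// "probability" column (soft membership) to the input DataFrame.
model.transform(df)
  .select("label", "prediction", "probability")
  .show(10, truncate = false)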