SparkMLlib实现K-means
引言
之前写过一篇关于kmeans的博客,里面详细的介绍了关于K-means的的详细描述,用python是实现的,并且在最后附带数据,了解更改关于K-means的内容详看K-means
今天用scala语言中的spark,使用MLlib库来实现
依赖
<!--mllib依赖,我用的是scala是2.11, spark是2.2.0-->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.6.0</version>
</dependency>
注意
和python相比,一样是调函数调参,但喂给model的数据类型和python不同,python中的SKLearing库使用的是矩阵或者是DataFrame,在spark里边要求的data是RDD[Vector]类型
/**
* Trains a k-means model using specified parameters and the default values for unspecified.
* 源码中最简单的训练函数要求的是传递三个参数,分别是数据集、族数、迭代次数
*/
def train(
data: RDD[Vector],
k: Int,
maxIterations: Int): KMeansModel = {
train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
}
代码
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD
object Kmeans {
def main(args: Array[String]): Unit = {
//模板代码,指定两个线程模拟在hadoop端的分布式
val conf = new SparkConf().setAppName("Kmeans").setMaster("local[2]")
val sc = new SparkContext(conf)
//加载数据
val data = sc.textFile("F:/mllib/kmeans/trainsdata")
//将数据切分成标志格式,并封装成linalg.Vector类型
val parsedData: RDD[linalg.Vector] = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
//迭代次数1000次、类簇的个数2个,进行模型训练形成数据模型
val numClusters = 4
val numIterations = 1000
//进行训练
val model = KMeans.train(parsedData, numClusters, numIterations)
//打印数据模型的中心点
println("四个中心的点:")
for (point <- model.clusterCenters) {
println(" " + point.toString)
}
//使用误差平方之和来评估数据模型,统计聚类错误的样本比例
val cost = model.computeCost(parsedData)
println("聚类错误的样本比例 = " + cost)
//对部分点做预测分类
println("点(-3 -3)所属族:" + model.predict(Vectors.dense("-3 -3".split(' ').map(_.toDouble))))
println("点(-2 3)所属族:" + model.predict(Vectors.dense("-2 3".split(' ').map(_.toDouble))))
println("点(3 3)所属族:" + model.predict(Vectors.dense("3 3".split(' ').map(_.toDouble))))
sc.stop()
}
}
运行结果
四个中心的点:
[-2.4615431500000002,2.78737555]
[-3.3823704500000007,-2.9473363000000004]
[2.6265298999999995,3.10868015]
[2.80293085,-2.7315146]
聚类错误的样本比例 = 149.95430467642632
点(-3 -3)所属族:1
点(-2 3)所属族:0
点(3 3)所属族:2