kmeans算法详解与spark实战

项目github地址:bitcarmanlee easy-algorithm-interview-and-practice
欢迎大家star,留言,一起学习进步

1.标准kmeans算法

kmeans算法是实际中最常用的聚类算法,没有之一。kmeans算法的原理简单,实现起来不是很复杂,实际中使用的效果一般也不错,所以深受广大人民群众的喜爱。
kmeans算法的原理介绍方面的paper多如牛毛,而且理解起来确实也不是很复杂,这里使用wiki上的版本:
已知观测集 ( x 1 , x 2 , ⋯   , x n ) (x_1,x_2,\cdots,x_n) (x1,x2,,xn),其中每个观测都是一个 d d d维实矢量,kmeans聚类要把这 n n n个观测值划分到 k k k个集合中( k ≤ n k\le n kn),使得组内平方和(WCSS within-cluster sum of squares)最小。换句话说,它的目标是找到使得下式满足的聚类 S i S_i Si
arg min ⁡ S ∑ i = 1 k ∑ x ∈ S i ∥ x − μ i ∥ 2 {\displaystyle {\underset {\mathbf {S} }{\operatorname {arg\,min} }}\sum _{i=1}^{k}\sum _{\mathbf {x} \in S_{i}}\left\|\mathbf {x} -{\boldsymbol {\mu }}_{i}\right\|^{2}} Sargmini=1kxSixμi2
其中 μ i \mu_i μi S i S_i Si中所有点的均值。

标准kmeans算法的步骤一般如下:
1.先随机挑选k个初始聚类中心。
2.计算数据集中每个点到每个聚类中心的距离,然后将这个点分配到离该点最近的聚类中心。
3.重新计算每个类中所有点的坐标的平均值,并将得到的这个新的点作为新的聚类中心。
重复上面第2、3步,知道聚类中心点不再大范围移动(精度自己定义)或者迭代的总次数达到最大。

2.标准kmeans算法的优缺点

标准的kmeans算法的优缺点都很突出。这里挑几个最重要的点总结一下。

主要优点:

1.原理简单,易于理解。
2.实现简单
3.计算速度较快
4.聚类效果还不错。

主要缺点:

1.需要确定k值。
2.对初始中心点的选择敏感。
3.对异常值敏感,因为异常值很很大程度影响聚类中心的位置。
4.无法增量计算。这点在数据量大的时候尤为突出。

3.spark中对kmeans的优化

作为经典的聚类算法,一般的机器学习框架里都实现由kmeans,spark自然也不例外。前面我们已经讲了标准kmeans的流程以及优缺点,那么针对标准kmeans中的不足,spark里主要做了如下的优化:

1.选择合适的K值。

k的选择是kmeans算法的关键。Spark MLlib在KMeansModel里实现了computeCost方法,这个方法通过计算数据集中所有的点到最近中心点的平方和来衡量聚类的效果。一般来说,同样的迭代次数,这个cost值越小,说明聚类的效果越好。但在实际使用过程中,必须还要考虑聚类结果的可解释性,不能一味地选择cost值最小的那个k。比如我们如果考虑极限情况,如果数据集有n个点,如果令k=n,每个点都是聚类中心,每个类都只有一个点,此时cost值最小为0。但是这样的聚类结果显然是没有实际意义的。

2.选择合适的初始中心点

大部分迭代算法都对初始值很敏感,kmeans也是如此。spark MLlib在初始中心点的选择上,使用了k-means++的算法。想要详细了解k-means++的同学们,可以参考k-means++在wiki上的介绍:https://en.wikipedia.org/wiki/K-means%2B%2B。
kmeans++的基本思想是是初始中心店的相互距离尽可能远。为了实现这个初衷,采取如下步骤:
1.从初始数据集中随机选择一个点作为第一个聚类中心点。
2.计算数据集中所有点到最近一个中心点的距离D(x)并存在一个数组里,然后将所有这些距离加起来得到Sum(D(x))。
3.然后再取一个随机值,用权重的方式计算下一个中心点。具体的实现方法:先取一个在Sum(D(x))范围内的随机值,然后领Random -= D(x),直至Random <= 0,此时这个D(x)对应的点为下一个中心点。
4.重复2、3步直到k个聚类中心点被找出。
5.利用找出的k个聚类中心点,执行标准的kmeans算法。

算法的关键是在第三步。有两个小点需要说明:
1.不能直接取距离最大的那个点当中心店。因为这个点很可能是离群点。
2.这种取随机值的方法能保证距离最大的那个点被选中的概率最大。给大家举个很简单的例子:假设有四个点A、B、C、D,分别离最近中心的距离D(x)为1、2、3、4,那么Sum(D(x))=10。然后在[0,10]之间取一随机数假设为random,然后用random与D(x)依次相减,直至random<0为止。应该不难发现,D被选中的概率最大。

4.spark实战kmeans算法

前面讲了这么多理论,照例咱们需要实践一把。talk is cheap,show me the code!

1.准备数据

首先准备数据集。这里采用的数据集是UCI的一个数据集。数据地址http://archive.ics.uci.edu/ml/datasets/Wholesale+customers?cm_mc_uid=70889544450214522232748&cm_mc_sid_50200000=1469871598。UCI是一个常用的标准测试数据集,是搞ML与DM同学经常使用的数据集。关于该数据集的介绍,同学们可以去网页上查看。

将数据下载下来以后查看一把,第一行相当于是表头,是对数据的相关说明。将此行去掉,还剩440行。将前400行作为训练集,后40行作为测试集。

2.将代码run起来

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

object KmeansTest {
  def main(args: Array[String]) {
  
    val conf = new
        SparkConf().setAppName("K-Means Clustering").setMaster("spark://your host:7077").setJars(List("your jar file"))
    val sc = new SparkContext(conf)

    val rawTrainingData = sc.textFile("file:///Users/lei.wang/data/data_training")
    val parsedTrainingData =
      rawTrainingData.filter(!isColumnNameLine(_)).map(line => {
        Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))
      }).cache()

    // Cluster the data into two classes using KMeans

    val numClusters = 8
    val numIterations = 30
    val runTimes = 3
    var clusterIndex: Int = 0
    val clusters: KMeansModel =
      KMeans.train(parsedTrainingData, numClusters, numIterations, runTimes)

    println("Cluster Number:" + clusters.clusterCenters.length)

    println("Cluster Centers Information Overview:")
    clusters.clusterCenters.foreach(
      x => {
        println("Center Point of Cluster " + clusterIndex + ":")
        println(x)
        clusterIndex += 1
      })

    //begin to check which cluster each test data belongs to based on the clustering result

    val rawTestData = sc.textFile("file:///Users/lei.wang/data/data_test")
    val parsedTestData = rawTestData.map(line => {
      Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))

    })
    parsedTestData.collect().foreach(testDataLine => {
      val predictedClusterIndex:
      Int = clusters.predict(testDataLine)
      println("The data " + testDataLine.toString + " belongs to cluster " +
        predictedClusterIndex)
    })

    println("Spark MLlib K-means clustering test finished.")
  }

  private def isColumnNameLine(line: String): Boolean = {
    if (line != null && line.contains("Channel")) true
    else false
  }
}

在本地将代码跑起来以后,输出如下:

...
Cluster Number:8
Cluster Centers Information Overview:
Center Point of Cluster 0:
[1.103448275862069,2.5517241379310343,39491.1724137931,4220.6551724137935,5250.172413793103,4478.103448275862,870.9655172413793,2152.8275862068967]
Center Point of Cluster 1:
[2.0,2.4210526315789473,7905.894736842105,20288.052631578947,30969.263157894737,2002.0526315789473,14125.105263157893,3273.4736842105262]
Center Point of Cluster 2:
[1.0,2.5,34782.0,30367.0,16898.0,48701.5,755.5,26776.0]
Center Point of Cluster 3:
[1.2190476190476192,2.5142857142857147,17898.97142857143,3221.7904761904765,4525.866666666667,3639.419047619048,1061.152380952381,1609.9047619047622]
Center Point of Cluster 4:
[1.8987341772151898,2.481012658227848,4380.5822784810125,9389.151898734177,14524.556962025315,1508.4556962025317,6457.683544303797,1481.1772151898733]
Center Point of Cluster 5:
[1.0817610062893082,2.4716981132075473,5098.270440251573,2804.295597484277,3309.0943396226417,2416.37106918239,901.1886792452831,803.0062893081762]
Center Point of Cluster 6:
[1.0,3.0,85779.66666666666,12503.666666666666,12619.666666666666,13991.666666666666,2159.0,3958.0]
Center Point of Cluster 7:
[2.0,3.0,29862.5,53080.75,60015.75,3262.25,27942.25,3082.25]
...

此部分内容为聚类中心点相关信息,我们将k设为8,所以一共有8个中心点。

...
The data [1.0,3.0,4446.0,906.0,1238.0,3576.0,153.0,1014.0] belongs to cluster 5
The data [1.0,3.0,27167.0,2801.0,2128.0,13223.0,92.0,1902.0] belongs to cluster 3
The data [1.0,3.0,26539.0,4753.0,5091.0,220.0,10.0,340.0] belongs to cluster 3
The data [1.0,3.0,25606.0,11006.0,4604.0,127.0,632.0,288.0] belongs to cluster 3
The data [1.0,3.0,18073.0,4613.0,3444.0,4324.0,914.0,715.0] belongs to cluster 3
The data [1.0,3.0,6884.0,1046.0,1167.0,2069.0,593.0,378.0] belongs to cluster 5
The data [1.0,3.0,25066.0,5010.0,5026.0,9806.0,1092.0,960.0] belongs to cluster 3
The data [2.0,3.0,7362.0,12844.0,18683.0,2854.0,7883.0,553.0] belongs to cluster 4
The data [2.0,3.0,8257.0,3880.0,6407.0,1646.0,2730.0,344.0] belongs to cluster 5
The data [1.0,3.0,8708.0,3634.0,6100.0,2349.0,2123.0,5137.0] belongs to cluster 5
The data [1.0,3.0,6633.0,2096.0,4563.0,1389.0,1860.0,1892.0] belongs to cluster 5
The data [1.0,3.0,2126.0,3289.0,3281.0,1535.0,235.0,4365.0] belongs to cluster 5
The data [1.0,3.0,97.0,3605.0,12400.0,98.0,2970.0,62.0] belongs to cluster 4
The data [1.0,3.0,4983.0,4859.0,6633.0,17866.0,912.0,2435.0] belongs to cluster 5
The data [1.0,3.0,5969.0,1990.0,3417.0,5679.0,1135.0,290.0] belongs to cluster 5
The data [2.0,3.0,7842.0,6046.0,8552.0,1691.0,3540.0,1874.0] belongs to cluster 5
The data [2.0,3.0,4389.0,10940.0,10908.0,848.0,6728.0,993.0] belongs to cluster 4
The data [1.0,3.0,5065.0,5499.0,11055.0,364.0,3485.0,1063.0] belongs to cluster 4
The data [2.0,3.0,660.0,8494.0,18622.0,133.0,6740.0,776.0] belongs to cluster 4
The data [1.0,3.0,8861.0,3783.0,2223.0,633.0,1580.0,1521.0] belongs to cluster 5
The data [1.0,3.0,4456.0,5266.0,13227.0,25.0,6818.0,1393.0] belongs to cluster 4
The data [2.0,3.0,17063.0,4847.0,9053.0,1031.0,3415.0,1784.0] belongs to cluster 3
The data [1.0,3.0,26400.0,1377.0,4172.0,830.0,948.0,1218.0] belongs to cluster 3
The data [2.0,3.0,17565.0,3686.0,4657.0,1059.0,1803.0,668.0] belongs to cluster 3
The data [2.0,3.0,16980.0,2884.0,12232.0,874.0,3213.0,249.0] belongs to cluster 3
The data [1.0,3.0,11243.0,2408.0,2593.0,15348.0,108.0,1886.0] belongs to cluster 3
The data [1.0,3.0,13134.0,9347.0,14316.0,3141.0,5079.0,1894.0] belongs to cluster 4
The data [1.0,3.0,31012.0,16687.0,5429.0,15082.0,439.0,1163.0] belongs to cluster 0
The data [1.0,3.0,3047.0,5970.0,4910.0,2198.0,850.0,317.0] belongs to cluster 5
The data [1.0,3.0,8607.0,1750.0,3580.0,47.0,84.0,2501.0] belongs to cluster 5
The data [1.0,3.0,3097.0,4230.0,16483.0,575.0,241.0,2080.0] belongs to cluster 4
The data [1.0,3.0,8533.0,5506.0,5160.0,13486.0,1377.0,1498.0] belongs to cluster 5
The data [1.0,3.0,21117.0,1162.0,4754.0,269.0,1328.0,395.0] belongs to cluster 3
The data [1.0,3.0,1982.0,3218.0,1493.0,1541.0,356.0,1449.0] belongs to cluster 5
The data [1.0,3.0,16731.0,3922.0,7994.0,688.0,2371.0,838.0] belongs to cluster 3
The data [1.0,3.0,29703.0,12051.0,16027.0,13135.0,182.0,2204.0] belongs to cluster 0
The data [1.0,3.0,39228.0,1431.0,764.0,4510.0,93.0,2346.0] belongs to cluster 0
The data [2.0,3.0,14531.0,15488.0,30243.0,437.0,14841.0,1867.0] belongs to cluster 1
The data [1.0,3.0,10290.0,1981.0,2232.0,1038.0,168.0,2125.0] belongs to cluster 5
The data [1.0,3.0,2787.0,1698.0,2510.0,65.0,477.0,52.0] belongs to cluster 5
...

此部分内容为测试集的聚类结果。因为我们选了40个样本作为测试集,所以此部分输出的内容一共有40行。

5.后续工作

本次测试是在单机上做的demo测试,数据集比较小,运算过程也比较快。其实当数据量增大以后,基本过程跟这是类似的,只需要将input改为集群的数据路径,然后再写个简单的shell脚本,调用spark-submit,将任务提交到集群即可。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值