Spark MLlib Kmeans源代码解读(上)

Spark MLlib Kmeans源代码解读(上)

PS:第一次写博客,希望大家支持,谢谢。

基本原理:Kmeans算法的基本思想是:初始随机给定k个簇中心,按照最近邻的点将数据集中所包含的点分给不同的中心点,进而得到数据的分类。在分类过程中,需要不停的进行迭代,同时更新中心点的坐标,直到中心点的移动距离小于某一个给定的精度值。

Kmeans的聚类算法主要分为以下三个步骤:

如下图所示


我们可以看到,图中有ABCDE共五个点。而灰色点是我们所选择的中心点。我们设置的k=2.

  • 1 随机在图中取K(这里K=2)个种子点。在这里k=2
  • 2 然后对图中的所有点我们计算到这K个种子点的距离,假如点A离种子点S最近,那么A属于S点群。(上图中,我们可以看到A,B属于上面的种子点,C,D,E属于下面中部的种子点)

  • 3 接下来,需要重新计算新的中心点,这个中心点会作为新的中心点

  • 4 然后重复第2)和第3)步,直到,种子点没有移动(我们可以看到图中的第四步上面的种子点聚合了A,B,C,下面的种子点聚合了D,E)。

kmeans的demo展示:Kmeans Demo

初始化聚类中心点:

对于初始化聚类中心点,我们可以随机筛选,但是这个会导致我们计算的结果差异很大。而且很有可能不合实际。Kmeans++就是选择初始中心点的一个比较好的算法。其基本思想如下:

  • 1 从初始的数据点随机筛选一个点做为初始中心点。
  • 2 对于数据集中的每个点,计算它与聚类中心点的距离D(x)。
  • 3 选择一个新的数据点作为新的聚类中心,选择的原则是D(x)较大的点被选取作为聚类中心的概率较大
  • 4 重复2 和3 直到k个聚类中心筛选出来。
  • 5 得到所需要的聚类中心点。

Spark Kmeans 源代码:

Spark的 Kmeans 包含Kmeans 和KmeansModel两个类。首先需要生成聚类中心点,支持两种方式来生成中心点,随机方式和Kmeans++方式。这个接下来会解释。然后迭代计算样本中的中心点,迭代的时候首先计算每个中心点的样本值之和,同时计算每个中心点样本的数量。通过这个来求解新的中心点。并且判断中心点是否发生改变。

注意在Spark 计算中有一个快速计算距离的方式。大致过程如下:


首先定义一个lowerBoundOfSqDist公式。假设中心点center是(a1,b1),需要计算的点是(a2,b2),那么lowerBoundOfSqDist是:

lowerBoundOfSqDist=(a12+b12b12+b22)2=a12+b12+a22+b222(a12+b12)(a22+b22)

对比欧式距离:

EuclideanDist=(a1a2)2+(b1b2)2=a12+b12+a22+b222(a1a2+b1b2)


我们可以清晰的看到lowerBoundOfSqDist是小于或等于EuclideanDist的,在进行距离比较的时候,先计算比较容易计算的lowerBoundOfSqDist。(只需要计算center和point的范数).当lowerBoundOfSqDist不小于之前计算得到的最小距离(在mllib中用bestDistance表示),那么真实的欧氏距离也不需要计算了。省去了很多工作。

当lowerBoundOfSqDist小于EuclideanDist的时候,进行距离的计算。在计算时调用的是MLUtils的fastSquaredDistance方法。这个方法首先会计算一个精度,precision=2.0 * epsilon* sumSquaredNorm/(normDiff * normDiff+ epsilon).假设精度满足条件,则欧式距离等于为EuclideanDist= sumSquaredNorm-2.0 * v1.dot(v2). 其中sumSquaredNorm为 a12+a22 . 2.0 * v1.dot(v2)=2(a1a2+b1b2)

当精度不满足要求的时候再接着进行欧式距离的计算。

好了,我们看看Kmeans的源代码吧!

Kmeans源代码包含Kmeans半生对象,Kmeans主类,和KmeansModel类和一个VectorWithNorm类(用来封装向量和其范数)。
其大概的结构如下所示:


Kmeans伴生对象包含如下方法:

  • 1 train 静态方法,通过设置输入参数,初始化一个Kmeans类,并调用其中的run方法来执行Kmeans的聚类.
  • 2 findClosest静态方法, 找到当前点距离最近的聚类中心。返回的结果是一个元祖(Int,Doublt). 其中Int表示的是聚类中心点的索引,Double表示的是距离。
  • 3 pointCost方法,其内部调用了findClosest方法,返回的是findClosest方法元祖的第二个值,表示的是cost距离。
  • 4 fastSquaredDistance方法,其内部调用了MLUtils类的工具方法 fastSquaredDistance来快速的计算距离。
  • 5 validateInitMode, 检查初始化中心点的模式,是random模式还是Kmeans++模式。

Kmeans类包含如下方法:

  • 1 首先其包含一些构造方法和一些参数的set和get方法.
  • 2 包含有run方法,但是其内部调用的是runAlgorithm方法
  • 3 包含一个initRandom方法,用来随机的初始化中心点,和一个initKMeansParallel方法,用来通过kmeans++方法来初始化中心点。

KmeansModel类主要包含一个load方法和一个save方法,用来进行模型的加载和保存。还有就是一个predict方法,调用这个方法可以用来预测样本属于哪个类。


首先我们来看Kmeans伴生对象

Kmeans伴生对象是建立Kmeans模型的主要的入口。它主要定义了训练Kmeans模型的train方法。
这个train方法主要包含有以下的几个参数。

  • 1 data(数据样本,格式为RDD[Vector])
  • 2 k(聚类数量,聚类中心点的数量,默认为2个)
  • 3 maxIterations(最大的迭代次数)
  • 4 runs(并行度,表示的是再这样一个kmeans算法中同时运行多少个kmeans算法,而后求他的最优值)
  • 5 initializationMode(初始化的中心的模式,random和kmeans++(默认))
  • 6 seed 初始化的随机种子。
看看代码
//首先是定义的train方法类,
 def train(
      data: RDD[Vector], //数据样本RDD[Vector] 
      k: Int,            //聚类中心点的个数
      maxIterations: Int,  //最大的迭代次数
      runs: Int,           //并行度,默认是1
      initializationMode: String, //初始化中心点的模式,默认为kmeans++
      seed: Long): KMeansModel = { 
    new KMeans().setK(k)
      .setMaxIterations(maxIterations)
      .setRuns(runs)
      .setInitializationMode(initializationMode)
      .setSeed(seed)
      .run(data)  //可以看到它初始化了一个Kmeans类,然后调用了run方法。
  }

其他还有很多类似的重载的train方法就不一一列举了。最后我会把所有的源代码放上去。


让我们接下来看看第二个方法findClosest,spark在这里有一个优化,这个原理在上面讲过了

private[mllib] def findClosest(
      centers: TraversableOnce[VectorWithNorm],
      point: VectorWithNorm): (Int, Double) = {
    var bestDistance = Double.PositiveInfinity  // 首先初始化了一个最优distance
    var bestIndex = 0 //初始化一个最后的中心点的索引
    var i = 0 //这个i的意思值得是第几个中心点

    centers.foreach { center => //这个表示对于每个中心点

      var lowerBoundOfSqDist = center.norm - point.norm //首先会去计算两个向量之间的范数的差值

      lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist //计算其平方
      //注意$$lowerBoundOfSqDist=(\sqrt{a1^{2}+b1^{2}}-\sqrt{b1^{2}+b2^{2}})^{2}=a1^{2}+b1^{2}+a2^{2}+b2^{2}-2\sqrt{(a1^{2}+b1^{2})(a2^{2}+b2^{2})}$$
      //而实际的欧式距离公式如下:$$EuclideanDist=(a1-a2)^{2}+(b1-
b2)^2=a1^{2}+b1^{2}+a2^{2}+b2^{2}-2(a1a2+b1b2)$$

      //可以证明范数差之积是小于或等于欧式距离的,也就是,这里当假设范数差的乘积都大于bestDistance的话,则实际的欧式距离肯定更大于 
      //这样就可以不需要计算欧式距离,而当范数差乘积小于bestDistance的时候。会调用一个fastSquaredDistance来进行计算。
            if (lowerBoundOfSqDist < bestDistance) { //
        val distance: Double = fastSquaredDistance(center, point)
        if (distance < bestDistance) {
          bestDistance = distance
          bestIndex = i //可以看到当小于的话,会更新这个index。
        }
      }
      i += 1
    }
    (bestIndex, bestDistance) //最后这个方法返回一个元祖(index,bestDistance)。其中bestIndex表示的是最合适的聚类中心点的索引。
  }

下面我们来看看里面调用的这个fastSquaredDistance方法,这个方法里面调用了MLUtils的fastSquaredDistance方法

  private[clustering] def fastSquaredDistance(
      v1: VectorWithNorm,
      v2: VectorWithNorm): Double = {
    MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
  }

// 这个是MLUtils里面的方法

    private[mllib] def fastSquaredDistance(
      v1: Vector,  // 向量1和其范数
      norm1: Double,
      v2: Vector, //向量2和其范数
      norm2: Double, 
      precision: Double = 1e-6): Double = { //精度值,precision

    val n = v1.size
    require(v2.size == n)
    require(norm1 >= 0.0 && norm2 >= 0.0)

    //注意 a1*a1+b1*b1+a2*a2+b2*b2-2(sqrt((a1*a1+b1*b1)+(a2*a2+b2*b2)))
      //而实际的欧式距离公式如下:a1*a1+b1*b1+a2*a2+b2*b2- 2(a1*a2+b1*b2)

    val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 //计算范数的乘积
    val normDiff = norm1 - norm2 
    var sqDist = 0.0

     //接下来计算精度
    val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)

    //如果精度满足范围,则它的距离就近似的等于sumSquaredNorm-2*v1.dot(v2),依旧不去计算欧式距离
    if (precisionBound1 < precision) {
      sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)

      //如果两个向量都是稀疏向量的话。则先计算两个向量的点积
    } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
      val dotValue = dot(v1, v2)

      sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
        (sqDist + EPSILON)
      if (precisionBound2 > precision) {
        //h后面是直接来计算向量的欧式距离
        sqDist = Vectors.sqdist(v1, v2)
      }
    } else {
      sqDist = Vectors.sqdist(v1, v2)
    }
    sqDist //最后返回欧式距离
  }

接下来我们看看这个比较小的一个类,VectorWithNorm类。

private[clustering]

//表示将向量和范数封装为一个类。两个this,分别为不同的构造方法
//自定义向量格式: (向量,向量的二范数)

class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable {

  def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))

  def this(array: Array[Double]) = this(Vectors.dense(array))

  /** Converts the vector to a dense vector. */
  //转变为密集向量
  def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
}

好了接下来是最重要的一个类 Kmeans类。

首先,我们看一下这个this构造方法。

 //初始化的参数值,默认迭代的中心点的数目为2个,迭代的最大的次数为20次,
 迭代
   /**
   * 初始化的并行度默认为1个,即只计算一个kmeans task,
   * 如果改变这个值之后,则会同时计算很多的kmeans任务,
   *然后选择所有的kmeans任务中最好的。默认这个值为1个。
   *初始化的中心点的模式,默认为parallel
   *初始化迭代的步长,默认为5
   *初始的精度值,默认为1e-4
   * 初始化种子,默认为一个随机值
   */

def this() = 
this(2, 20, 1, 
KMeans.K_MEANS_PARALLEL, 
5, 1e-4, Utils.random.nextLong())

接下来有很多set和get方法,就不一一解释了

@Since("1.4.0")
  def getK: Int = k

  /**
   * Set the number of clusters to create (k). Default: 2.
   */
  @Since("0.8.0")
  def setK(k: Int): this.type = {
    this.k = k
    this
  }

  /**
   * Maximum number of iterations to run.
   */
  @Since("1.4.0")
  def getMaxIterations: Int = maxIterations

  /**
   * Set maximum number of iterations to run. Default: 20.
   */
  @Since("0.8.0")
  def setMaxIterations(maxIterations: Int): this.type = {
    this.maxIterations = maxIterations
    this
  }

我们来看看这个run方法.

def run(data: RDD[Vector]): KMeansModel = {

    //如果数据点并没有被cacheed的话,会弹出一个警告,说明需要被cache在内存,这样子
    //方便后面的迭代运算
    //注意spark的cache有好几层,比如说MEMORY_ONLY,MEMORY_AND_DISK。等等。默认是MEMORY_ONLY,这个其他博客有资料,可以查看,这里就不解释了

    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data is not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }

    // Compute squared norms and cache them.

    //计算数据点的二范数。
    val norms = data.map(Vectors.norm(_, 2.0))

    norms.persist() /// 将二范数缓存到内存
    ///利用拉链操作将数据和数据的范数连接起来。同时以一个新的内部类对象的格式存储
    //内部类的格式为向量和二范数的格式

    val zippedData = data.zip(norms).map { case (v, norm) =>
      new VectorWithNorm(v, norm)
    }

    //调用runalgorithm的方法来计算模型,最后去掉在内存中的缓存。
    //
    val model = runAlgorithm(zippedData)
    norms.unpersist()

    // Warn at the end of the run as well, for increased visibility.
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data was not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }
    //然后返回模型
    model
  }

接下来这个方法就是Kmeans的核心方法runAlgorithm方法。我们来看看spark是如何实现的。


runAlgorithm方法

这个方法的分布式实现方式大致如下:


首先,通过初始化中心点的两个方法之一获取聚类中心点,然后将这个中心点广播到每一个RDD.然后使用mapPartitions算子,对于每一个分区,取得计算的聚类中心点,包括聚类的并行计算数量(runs),聚类的数量(k)等等参数。然后接下来根据每个分区的样本数据(即第二个points.foreach方法),计算与其相距最近的中心点,注意在统计完以后,计算了一个contribs,这个值表示的是(第i个并行度下,第j个聚类中心的聚类的点的距离的sum值,和点的数量的值),它的格式是这样的,((i,j)(第i个并行度,第j个聚类中心),(sums(i)(j),counts(i)(j))).它返回一个迭代器,然后在reduceByKey。
这个reduceByKey方法它的key指的就是(i,j)即相同的并行度和相同的第j个聚类中心点,但是由于是在不同的分区,所以说需要进行汇总。即在每一个分区的并行计算的结果进行汇总。最后调用collectAsMap算子进行输出。
对于输出的结果在进行判断,若两次迭代的中心点的差值小于一定的精度,则认为这个聚类成功。否则继续迭代。


具体的每一步的执行流程和代码讲解如下:

   private def runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {

    val sc = data.sparkContext  //获取这个rdd数据的sparkContext

    val initStartTime = System.nanoTime()

    // Only one run is allowed when initialModel is given
    //kMeans任务的并行度,默认为一个。表示总共有一个kmeans任务运行
    val numRuns = if (initialModel.nonEmpty) {  
      if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
      1
    } else {
      runs
    }

    val centers = initialModel match { ///初始化的模式,
      case Some(kMeansCenters) => {
        Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
      }
      case None => {
        if (initializationMode == KMeans.RANDOM) { //如果是random 类型的,则调用initrandom方法
          initRandom(data)
        } else { //否则调用parallel方法, 即kmeans++方法
          initKMeansParallel(data)
        }
      }
    }
    val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
    logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
      " seconds.")

      //初始化一个数组,大小为并行度,默认的结果为true。
    val active = Array.fill(numRuns)(true)
    //初始化一个数组用来存放每个并行kmeans任务下的cost,cost最小的会被选为合适的中心点
    val costs = Array.fill(numRuns)(0.0)

     //一个新的缓冲数组,值为0,1,2.....numRuns。依旧在当有多个并行度的情况下发挥作用,当runs值为1的时候
     //没有效果
    var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
    var iteration = 0  //设置当前的迭代次数为0,后面每一次迭代,都会对这个值进行增加

    val iterationStartTime = System.nanoTime()

    // Execute iterations of Lloyd's algorithm until all runs have converged
    while (iteration < maxIterations && !activeRuns.isEmpty) {

      type WeightedPoint = (Vector, Long)  //将这个元祖作为一个整体,表示为weightedpoint

      def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {

        axpy(1.0, x._1, y._1) //常数乘以向量加另一个向量 y._1=x._1+y._1
        (y._1, x._2 + y._2)  //返回一个元祖
      }

      //存储的是每一个并行度下的聚类中心
      val activeCenters = activeRuns.map(r => centers(r)).toArray
       //存储的是每一个并行度下的cost的累加值的初始值,其初始值为0.0. 每个并行计算的kmeans任务会对这个值进行累加。进而计算cost
      val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
       //将中心点数组广播到每一个rdd
      val bcActiveCenters = sc.broadcast(activeCenters)

      // Find the sum and count of points mapping to each center
      //对于每个分区计算局中心点最近的点
      val totalContribs = data.mapPartitions { points =>

         //计算每一个中心点的样本,对每一个中心店的样本进行累加和计算
         //runs代表并行度,k代表中心点的个数,sums代表中心点的累加值,
         //counts代表的是中心点的样本计数,contribs代表的是(并行度i,中心j),(中心j的样本之和,中心j的样本个数)
         //找到点与所有的聚类中心点最近的一个中心

        val thisActiveCenters = bcActiveCenters.value //表示的是聚类中心点的值
        val runs = thisActiveCenters.length   //表示的是kmeans任务的并行度,即同时进行多少个kmeans任务 
        val k = thisActiveCenters(0).length   //表示的是中心点的个数。
        val dims = thisActiveCenters(0)(0).vector.size  //表示的是中心点的维度的值 

        val sums = Array.fill(runs, k)(Vectors.zeros(dims))  //sums表示在当前第i个并行度的情况下,第k个中心点的其他个点距离这个点的
        //距离的和
        val counts = Array.fill(runs, k)(0L) //counts,表示在当前第i个并行度的情况下,第k个中心点的这个类的点的数量

        points.foreach { point =>   //针对每一个样本,并行的计算
          (0 until runs).foreach { i =>

            val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point) //这个表示计算的样本属于哪个中心点
            costAccums(i) += cost  //并行度下的cost之和。
            val sum = sums(i)(bestCenter) 
            axpy(1.0, point.vector, sum) //sum=sum+point
            counts(i)(bestCenter) += 1  // counts的数目累加。
          }
        }

        //每一个聚类下样本的向量和,样本点的数目的和counts,还有就是i和j,其中i表示的是并行度,j表示的是聚类中心
        val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
          ((i, j), (sums(i)(j), counts(i)(j)))
        }
        contribs.iterator  //表示的是对于刚才的一个mappartitions的操作之后返回一个iterator
      }.reduceByKey(mergeContribs).collectAsMap()  //调用刚才的mergeContribs函数 
      //这一步很关键。在mappartition方法完成了之后,调用了reducebykey方法,对于有相同key的数据进行汇总。
      //对属于同一中心点下的样本向量之和和样本数量进行累加的操作
      bcActiveCenters.unpersist(blocking = false)

      // Update the cluster centers and costs for each active run
      //更新中心点,更新中心点=sum/count
      //判断newCenter和centers之间的距离是否是大于epsilon的平方
      for ((run, i) <- activeRuns.zipWithIndex) {
        var changed = false
        var j = 0   //j表示的是中心点的数目,i表示的是第i个并行度
        while (j < k) {
          val (sum, count) = totalContribs((i, j)) //获取到当前的sum和count,求解信的中心点
          if (count != 0) {
            scal(1.0 / count, sum) //sum=sum/count。这个是一个向量除法运算
            val newCenter = new VectorWithNorm(sum)

            if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
              //如果大于精度值的话,则认为中心点改变了。则更新中心点
              changed = true
            }
            centers(run)(j) = newCenter
          }
          j += 1
        }
        if (!changed) { //如果整体没有改变,则完成
          active(run) = false
          logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations")
        }
        costs(run) = costAccums(i).value
      }

      activeRuns = activeRuns.filter(active(_)) //过滤掉已经收敛的并行计算的
      iteration += 1
    }

    val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
    logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.")

    if (iteration == maxIterations) {
      logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
    } else {
      logInfo(s"KMeans converged in $iteration iterations.")
    }

    val (minCost, bestRun) = costs.zipWithIndex.min  //找到最小的值作为最优的中心点

    logInfo(s"The cost for the best run is $minCost.")

    new KMeansModel(centers(bestRun).map(_.vector))  //返回一个新的kmeansmodel
  }

下面我们来看看初始化中心点的两个方法。
第一个是random方法,它的思想比较简单,就是随机的从数据集中抽取出k个中心点。

/**
   * Initialize `runs` sets of cluster centers at random.
   *中心点的初始化,目前支持的是random和kmeans++方法
   */
  private def initRandom(data: RDD[VectorWithNorm]) //随机选出中心点
  : Array[Array[VectorWithNorm]] = {

    //随机抽取出中心点,其中runs为并行度,k为中心点的个数。
    val sample = data.takeSample(true, runs * k, new       XORShiftRandom(this.seed).nextInt()).toSeq

    //创建一个新的数组,数组里面的元素是VectorWithNorm格式
    Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
      //返回的格式为中心点和中心点的范数
      new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
    }.toArray)
  }

第二个方法是通过kmeans++方式来初始化中心点,原文可以再这个链接找到
http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.


private def initKMeansParallel(data: RDD[VectorWithNorm])
  : Array[Array[VectorWithNorm]] = {

    //初始化中心及costs,tabluate方法返回一个数组,长度为runs。
    val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
    var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))

    // Initialize each run's first center to a random point.
    val seed = new XORShiftRandom(this.seed).nextInt()


    //第一步
    //初始化第一个中心点,随机
    val sample = data.takeSample(true, runs, seed).toSeq //随机筛选出一些中心点。 
    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) //获取一个长度为为runs的中心点数组 

    /** Merges new centers to centers. */
    def mergeNewCenters(): Unit = { //合并新的中心点到中心
      var r = 0
      while (r < runs) {
        centers(r) ++= newCenters(r)
        newCenters(r).clear()
        r += 1
      }
    }


    //第二步,通过已知的中心点,循环迭代求得其他的中心点。

    //每次迭代的过程,抽取2*k个样本,每次迭代计算样本点与中心店的距离
    var step = 0
    while (step < initializationSteps) {
      val bcNewCenters = data.context.broadcast(newCenters) //新的中心点
      val preCosts = costs

      //这个cost表示的是每个点距离最近中心点的代价。
     //j将数据点和cost通过拉链操作连接在一起,返回一个(point,cost)
     //这个会在下一步的时候通过调用math。min方法,找出cost(r) 和通过kmeans的pointcost方法返回的最小的cost值
     //并将这个值更新到costs数组。同时将这个costs数组cache到缓存。
      costs = data.zip(preCosts).map { case (point, cost) =>
          Array.tabulate(runs) { r =>
            math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
          }
        }.persist(StorageLevel.MEMORY_AND_DISK)





      //接下来聚合costs,聚合后的返回值,为一个Array,其内部的元素为Double类型。
      //注意接下来这个aggregate算子,它接收三个参数,第一个参数接收一个初始值,这个初始值,首先会作用到每个分区,
      //应用于每个分区的函数是接下来定义的第一个SeqOp函数。
      //这个参数会在每个分区发挥作用。
      //接下来有一个combOp函数,这个函数会在对每一个聚合后的结果发挥作用。相当于前面函数作用后的结果,会在后面继续发挥作用.
      //需要注意的是,aggregate算子,的初始参数在第二个函数页发挥作用。而aggregateByKey算子不会发挥作用。

      val sumCosts = costs
        .aggregate(new Array[Double](runs))(
          //接下来计算的方式如下所示,由于有一个并行度,每一个并行度会有一个自己的costs数组,所以计算costs数组的时候会
          // 分开对每一个costs数组进行计算。第一个函数的意思是:合并第二个参数v(也是一个costs值)到第一个s里面。
         //然后返回一个s数组。相当于每一个分区都会有这个s数组


          seqOp = (s, v) => { //这个表示分区内迭代
            // s += v
            var r = 0
            while (r < runs) {
              s(r) += v(r)
              r += 1
            }
            s  
          },




          //接下来对于不同的分区,计算s数组的值的和。和刚才一样,因为有一个并行度,所以默认的是对每一个并行度下的cost数组进行计算。
          //然后返回这个s0。s0是一个数组。从0~runs的一个double类型的数组。每一个对应的元素包含的是对应的在第r次并行度下的cost值

          combOp = (s0, s1) => { //这个表示分区间合并
            // s0 += s1
            var r = 0
            while (r < runs) {
              s0(r) += s1(r)
              r += 1
            }
            s0
          }
        )

      bcNewCenters.unpersist(blocking = false) // 去掉在内存中的缓存
      preCosts.unpersist(blocking = false)

      //选择满足概率得点。      
      val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>

        val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
        pointsWithCosts.flatMap { case (p, c) =>
          val rs = (0 until runs).filter { r =>
            rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
          }
          if (rs.length > 0) Some(p, rs) else None
        }
      }.collect()

      mergeNewCenters()
      chosen.foreach { case (p, rs) =>
        rs.foreach(newCenters(_) += p.toDense)
      }
      step += 1
    }

    mergeNewCenters()
    costs.unpersist(blocking = false)


//第三步,求得最终的k个点
//通过以上步骤求得的候选中心点的个数可能会多于`k`个,这样怎么办呢?我们给每个中心点赋//一个权重,权重值是数据集中属于该中心点所在类别的数据点的个数。
//然后我们使用本地`k-means++`来得到这`k`个初始化点。具体的实现代码如下:
    val bcCenters = data.context.broadcast(centers)
    val weightMap = data.flatMap { p =>
      Iterator.tabulate(runs) { r =>
        ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
      }
    }.reduceByKey(_ + _).collectAsMap()

    bcCenters.unpersist(blocking = false)

    val finalCenters = (0 until runs).par.map { r =>
      val myCenters = centers(r).toArray
      val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
      LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
    }

    finalCenters.toArray
  }
}

接下来我们来看看KMeansModel类


KMeansModel类包含了中心点的向量,KmeansModel包含了预测,保存模型和加载模型的相关方法。其中predict方法调用了KMeans的findClosest方法来判断这个向量数据哪个中心点。

这个是其中的predict方法


 //predict方法通过findClostest方法来查找最近的中心点。
  //首先将和范数在一起的向量值广播到每一个rdd去。然后通过findClosest方法来查找属于哪个中心点

  def predict(points: RDD[Vector]): RDD[Int] = {
    val centersWithNorm = clusterCentersWithNorm 
    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
  }

其中clusterCentersWithNorm表示的是将每个中心点向量以VectorWithNorm的形式返回。

private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
    clusterCenters.map(new VectorWithNorm(_))

computeCost方法,用于计算样本点到最近中心点的距离平方之和。其内部调用了pointCost方法,返回的是一个sum值。

 //计算cost,首先也是广播中心点的向量和范数到每个rdd,然后接下来计算到每个中心点的距离的和s
  def computeCost(data: RDD[Vector]): Double = {
    val centersWithNorm = clusterCentersWithNorm
    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
    data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
  }

还有一个就是保存模型的方法和加载模型的方法。

 //保存模型的方法
     def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
      val sqlContext = SQLContext.getOrCreate(sc)
      import sqlContext.implicits._
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
        Cluster(id, point)
      }.toDF()
      dataRDD.write.parquet(Loader.dataPath(path))
    }
  def load(sc: SparkContext, path: String): KMeansModel = {
      implicit val formats = DefaultFormats
      val sqlContext = SQLContext.getOrCreate(sc)
      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
      assert(className == thisClassName)
      assert(formatVersion == thisFormatVersion)
      val k = (metadata \ "k").extract[Int]
      val centroids = sqlContext.read.parquet(Loader.dataPath(path))
      Loader.checkSchema[Cluster](centroids.schema)
      val localCentroids = centroids.map(Cluster.apply).collect()
      assert(k == localCentroids.size)
      new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
    }

好了以上就是关于Spark MLlib Kmeans的源代码的解析。

参考文献和资料

  • [1 Bahman Bahmani,Benjamin Moseley,Andrea Vattani.Scalable K-Means++](papers/Scalable K-Means++.pdf)
  • [2 David Arthur and Sergei Vassilvitskii.k-means++: The Advantages of Careful Seeding](papers/k-means++: The Advantages of Careful Seeding.pdf)
  • [ 3 深入浅出Kmeans] (http://www.csdn.net/article/2012-07-03/2807073-k-means)
  • 4 Spark机器学习。
  • 5 本代码来自于 Spark1.6.1
  • 10
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值